Skip to content

Commit

Permalink
fix(mssql): fix temporary table creation and implement cache (#9434)
Browse files Browse the repository at this point in the history
  • Loading branch information
gforsyth authored Jun 24, 2024
1 parent f3cd8b2 commit 196d8a1
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 14 deletions.
85 changes: 73 additions & 12 deletions ibis/backends/mssql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,12 @@ def do_connect(
def get_schema(
self, name: str, *, catalog: str | None = None, database: str | None = None
) -> sch.Schema:
# TODO: this is brittle and should be improved. We want to be able to
# identify if a given table is a temp table and update the search
# location accordingly.
if name.startswith("ibis_cache_"):
catalog, database = ("tempdb", "dbo")
name = "##" + name
conditions = [sg.column("table_name").eq(sge.convert(name))]

if database is not None:
Expand Down Expand Up @@ -481,20 +487,56 @@ def create_table(
temp: bool = False,
overwrite: bool = False,
) -> ir.Table:
"""Create a new table.
Parameters
----------
name
Name of the new table.
obj
An Ibis table expression or pandas table that will be used to
extract the schema and the data of the new table. If not provided,
`schema` must be given.
schema
The schema for the new table. Only one of `schema` or `obj` can be
provided.
database
Name of the database where the table will be created, if not the
default.
To specify a location in a separate catalog, you can pass in the
catalog and database as a string `"catalog.database"`, or as a tuple of
strings `("catalog", "database")`.
temp
Whether a table is temporary or not.
All created temp tables are "Global Temporary Tables". They will be
created in "tempdb.dbo" and will be prefixed with "##".
overwrite
Whether to clobber existing data.
`overwrite` and `temp` cannot be used together with MSSQL.
Returns
-------
Table
The table that was created.
"""
if obj is None and schema is None:
raise ValueError("Either `obj` or `schema` must be specified")

if database is not None and database != self.current_database:
raise com.UnsupportedOperationError(
"Creating tables in other databases is not supported by Postgres"
if temp and overwrite:
raise ValueError(
"MSSQL doesn't support overwriting temp tables, create a new temp table instead."
)
else:
database = None

table_loc = self._to_sqlglot_table(database)
catalog, db = self._to_catalog_db_tuple(table_loc)

properties = []

if temp:
properties.append(sge.TemporaryProperty())
catalog, db = None, None

temp_memtable_view = None
if obj is not None:
Expand Down Expand Up @@ -528,8 +570,10 @@ def create_table(
else:
temp_name = name

table = sg.table(temp_name, catalog=database, quoted=self.compiler.quoted)
raw_table = sg.table(temp_name, catalog=database, quoted=False)
table = sg.table(
"#" * temp + temp_name, catalog=catalog, db=db, quoted=self.compiler.quoted
)
raw_table = sg.table(temp_name, catalog=catalog, db=db, quoted=False)
target = sge.Schema(this=table, expressions=column_defs)

create_stmt = sge.Create(
Expand All @@ -538,11 +582,22 @@ def create_table(
properties=sge.Properties(expressions=properties),
)

this = sg.table(name, catalog=database, quoted=self.compiler.quoted)
raw_this = sg.table(name, catalog=database, quoted=False)
this = sg.table(name, catalog=catalog, db=db, quoted=self.compiler.quoted)
raw_this = sg.table(name, catalog=catalog, db=db, quoted=False)
with self._safe_raw_sql(create_stmt) as cur:
if query is not None:
insert_stmt = sge.Insert(this=table, expression=query).sql(self.dialect)
# You can specify that a table is temporary for the sqlglot `Create` but not
# for the subsequent `Insert`, so we need to prepend `##` (the global temp
# table prefix) to the table identifier ourselves.
_table = sg.table(
"##" * temp + temp_name,
catalog=catalog,
db=db,
quoted=self.compiler.quoted,
)
insert_stmt = sge.Insert(this=_table, expression=query).sql(
self.dialect
)
cur.execute(insert_stmt)

if overwrite:
Expand All @@ -558,11 +613,17 @@ def create_table(
# for in-memory reads
if temp_memtable_view is not None:
self.drop_table(temp_memtable_view)
return self.table(name, database=database)
return self.table(
"##" * temp + name,
database=("tempdb" * temp or catalog, "dbo" * temp or db),
)

# preserve the input schema if it was provided
return ops.DatabaseTable(
name, schema=schema, source=self, namespace=ops.Namespace(database=database)
name,
schema=schema,
source=self,
namespace=ops.Namespace(catalog=catalog, database=db),
).to_expr()

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
Expand Down
18 changes: 18 additions & 0 deletions ibis/backends/mssql/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,21 @@ def test_list_tables_schema_warning_refactor(con):

assert con.list_tables(database="msdb.dbo", like="restore") == restore_tables
assert con.list_tables(database=("msdb", "dbo"), like="restore") == restore_tables


def test_create_temp_table_from_obj(con):
    """Create a temp table from an in-memory object and persist a copy of it.

    Checks that a table created with ``temp=True`` can be read back through
    its ``##``-prefixed name in ``tempdb.dbo``, and that a regular table can
    be created from that temp table with identical contents.
    """
    obj = {"team": ["john", "joe"]}

    t = con.create_table("team", obj, temp=True)

    # The temp table is addressed by its "##"-prefixed name in "tempdb.dbo".
    t2 = con.table("##team", database="tempdb.dbo")

    assert t.to_pyarrow().equals(t2.to_pyarrow())

    persisted_from_temp = con.create_table("fuhreal", t2)

    # Drop the persisted table even when an assertion fails, so that a failed
    # run does not leave "fuhreal" behind for subsequent runs.
    try:
        assert "fuhreal" in con.list_tables()

        assert persisted_from_temp.to_pyarrow().equals(t2.to_pyarrow())
    finally:
        con.drop_table("fuhreal")
7 changes: 5 additions & 2 deletions ibis/backends/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,11 @@ def test_create_table(backend, con, temp_table, func, sch):
["pyspark", "trino", "exasol", "risingwave"],
reason="No support for temp tables",
),
pytest.mark.broken(["mssql"], reason="Incorrect temp table syntax"),
pytest.mark.notyet(
["mssql"],
reason="Can't rename temp tables",
raises=ValueError,
),
pytest.mark.broken(
["bigquery"],
reason="tables created with temp=True cause a 404 on retrieval",
Expand Down Expand Up @@ -1722,7 +1726,6 @@ def test_json_to_pyarrow(con):
assert result == expected


@pytest.mark.notyet(["mssql"], raises=PyODBCProgrammingError)
@pytest.mark.notyet(
["risingwave", "exasol"],
raises=com.UnsupportedOperationError,
Expand Down

0 comments on commit 196d8a1

Please sign in to comment.