Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(mssql): fix temporary table creation and implement cache #9434

Merged
merged 1 commit into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 73 additions & 12 deletions ibis/backends/mssql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,12 @@ def do_connect(
def get_schema(
self, name: str, *, catalog: str | None = None, database: str | None = None
) -> sch.Schema:
# TODO: this is brittle and should be improved. We want to be able to
# identify if a given table is a temp table and update the search
# location accordingly.
if name.startswith("ibis_cache_"):
cpcloud marked this conversation as resolved.
Show resolved Hide resolved
catalog, database = ("tempdb", "dbo")
name = "##" + name
conditions = [sg.column("table_name").eq(sge.convert(name))]

if database is not None:
Expand Down Expand Up @@ -481,20 +487,56 @@ def create_table(
temp: bool = False,
overwrite: bool = False,
) -> ir.Table:
"""Create a new table.

Parameters
----------
name
Name of the new table.
obj
An Ibis table expression or pandas table that will be used to
extract the schema and the data of the new table. If not provided,
`schema` must be given.
schema
The schema for the new table. Only one of `schema` or `obj` can be
provided.
database
Name of the database where the table will be created, if not the
default.

To specify a location in a separate catalog, you can pass in the
catalog and database as a string `"catalog.database"`, or as a tuple of
strings `("catalog", "database")`.
temp
Whether a table is temporary or not.
All created temp tables are "Global Temporary Tables". They will be
created in "tempdb.dbo" and will be prefixed with "##".
overwrite
Whether to clobber existing data.
`overwrite` and `temp` cannot be used together with MSSQL.

Returns
-------
Table
The table that was created.

"""
if obj is None and schema is None:
raise ValueError("Either `obj` or `schema` must be specified")

if database is not None and database != self.current_database:
raise com.UnsupportedOperationError(
"Creating tables in other databases is not supported by Postgres"
if temp and overwrite:
cpcloud marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(
"MSSQL doesn't support overwriting temp tables, create a new temp table instead."
)
else:
database = None

table_loc = self._to_sqlglot_table(database)
catalog, db = self._to_catalog_db_tuple(table_loc)

properties = []

if temp:
properties.append(sge.TemporaryProperty())
catalog, db = None, None

temp_memtable_view = None
if obj is not None:
Expand Down Expand Up @@ -528,8 +570,10 @@ def create_table(
else:
temp_name = name

table = sg.table(temp_name, catalog=database, quoted=self.compiler.quoted)
raw_table = sg.table(temp_name, catalog=database, quoted=False)
table = sg.table(
"#" * temp + temp_name, catalog=catalog, db=db, quoted=self.compiler.quoted
)
raw_table = sg.table(temp_name, catalog=catalog, db=db, quoted=False)
target = sge.Schema(this=table, expressions=column_defs)

create_stmt = sge.Create(
Expand All @@ -538,11 +582,22 @@ def create_table(
properties=sge.Properties(expressions=properties),
)

this = sg.table(name, catalog=database, quoted=self.compiler.quoted)
raw_this = sg.table(name, catalog=database, quoted=False)
this = sg.table(name, catalog=catalog, db=db, quoted=self.compiler.quoted)
raw_this = sg.table(name, catalog=catalog, db=db, quoted=False)
with self._safe_raw_sql(create_stmt) as cur:
if query is not None:
insert_stmt = sge.Insert(this=table, expression=query).sql(self.dialect)
# You can specify that a table is temporary for the sqlglot `Create` but not
# for the subsequent `Insert`, so we need to prepend the global temp table
# prefix `##` to the table identifier ourselves.
_table = sg.table(
"##" * temp + temp_name,
catalog=catalog,
db=db,
quoted=self.compiler.quoted,
)
insert_stmt = sge.Insert(this=_table, expression=query).sql(
self.dialect
)
cur.execute(insert_stmt)

if overwrite:
Expand All @@ -558,11 +613,17 @@ def create_table(
# for in-memory reads
if temp_memtable_view is not None:
self.drop_table(temp_memtable_view)
return self.table(name, database=database)
return self.table(
"##" * temp + name,
database=("tempdb" * temp or catalog, "dbo" * temp or db),
)

# preserve the input schema if it was provided
return ops.DatabaseTable(
name, schema=schema, source=self, namespace=ops.Namespace(database=database)
name,
schema=schema,
source=self,
namespace=ops.Namespace(catalog=catalog, database=db),
).to_expr()

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
Expand Down
18 changes: 18 additions & 0 deletions ibis/backends/mssql/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,21 @@ def test_list_tables_schema_warning_refactor(con):

assert con.list_tables(database="msdb.dbo", like="restore") == restore_tables
assert con.list_tables(database=("msdb", "dbo"), like="restore") == restore_tables


def test_create_temp_table_from_obj(con):
    """Create a temp table from in-memory data and persist a copy of it.

    The test checks three things:

    * the expression returned by ``create_table(..., temp=True)`` holds the
      same data as the table looked up as ``##team`` in ``tempdb.dbo``;
    * a regular table can be created from that temp table expression and
      shows up in ``list_tables()``;
    * the persisted copy holds the same data as the temp table.

    Parameters
    ----------
    con
        Backend connection fixture under test.
    """
    obj = {"team": ["john", "joe"]}

    t = con.create_table("team", obj, temp=True)

    # Look the same table up explicitly by its prefixed name and location.
    t2 = con.table("##team", database="tempdb.dbo")

    assert t.to_pyarrow().equals(t2.to_pyarrow())

    persisted_from_temp = con.create_table("fuhreal", t2)
    try:
        assert "fuhreal" in con.list_tables()

        assert persisted_from_temp.to_pyarrow().equals(t2.to_pyarrow())
    finally:
        # Drop the persistent table even when an assertion fails, so a failed
        # run does not leave "fuhreal" behind for the next one.
        con.drop_table("fuhreal")
7 changes: 5 additions & 2 deletions ibis/backends/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,11 @@ def test_create_table(backend, con, temp_table, func, sch):
["pyspark", "trino", "exasol", "risingwave"],
reason="No support for temp tables",
),
pytest.mark.broken(["mssql"], reason="Incorrect temp table syntax"),
pytest.mark.notyet(
["mssql"],
reason="Can't rename temp tables",
raises=ValueError,
),
pytest.mark.broken(
["bigquery"],
reason="tables created with temp=True cause a 404 on retrieval",
Expand Down Expand Up @@ -1722,7 +1726,6 @@ def test_json_to_pyarrow(con):
assert result == expected


@pytest.mark.notyet(["mssql"], raises=PyODBCProgrammingError)
@pytest.mark.notyet(
["risingwave", "exasol"],
raises=com.UnsupportedOperationError,
Expand Down