Skip to content

Commit

Permalink
fix(mssql): remove sort key to keep order (#9848)
Browse files Browse the repository at this point in the history
Co-authored-by: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
  • Loading branch information
grieve54706 and cpcloud authored Aug 15, 2024
1 parent c99cb4b commit 3780a13
Showing 2 changed files with 65 additions and 29 deletions.
25 changes: 11 additions & 14 deletions ibis/backends/mssql/__init__.py
Original file line number Diff line number Diff line change
@@ -244,24 +244,21 @@ def _get_schema_using_query(self, query: str) -> sch.Schema:
# us to pre-filter the columns we want back.
# The syntax is:
# `sys.dm_exec_describe_first_result_set(@tsql, @params, @include_browse_information)`
query = f"""SELECT name,
is_nullable AS nullable,
system_type_name,
precision,
scale
FROM
sys.dm_exec_describe_first_result_set({tsql}, NULL, 0)"""
query = f"""
SELECT
name,
is_nullable,
system_type_name,
precision,
scale
FROM sys.dm_exec_describe_first_result_set({tsql}, NULL, 0)
ORDER BY column_ordinal
"""
with self._safe_raw_sql(query) as cur:
rows = cur.fetchall()

schema = {}
for (
name,
nullable,
system_type_name,
precision,
scale,
) in sorted(rows, key=itemgetter(1)):
for name, nullable, system_type_name, precision, scale in rows:
newtyp = self.compiler.type_mapper.from_string(
system_type_name, nullable=nullable
)
69 changes: 54 additions & 15 deletions ibis/backends/mssql/tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from __future__ import annotations

import pytest
import sqlglot as sg
import sqlglot.expressions as sge
from pytest import param

import ibis
import ibis.expr.datatypes as dt
from ibis import udf

DB_TYPES = [
RAW_DB_TYPES = [
# Exact numbers
("BIGINT", dt.int64),
("BIT", dt.boolean),
@@ -36,23 +38,9 @@
("DATETIME", dt.Timestamp(scale=3)),
# Characters strings
("CHAR", dt.string),
param(
"TEXT",
dt.string,
marks=pytest.mark.notyet(
["mssql"], reason="Not supported by UTF-8 aware collations"
),
),
("VARCHAR", dt.string),
# Unicode character strings
("NCHAR", dt.string),
param(
"NTEXT",
dt.string,
marks=pytest.mark.notyet(
["mssql"], reason="Not supported by UTF-8 aware collations"
),
),
("NVARCHAR", dt.string),
# Binary strings
("BINARY", dt.binary),
@@ -67,6 +55,23 @@
("GEOGRAPHY", dt.geography),
("HIERARCHYID", dt.string),
]
PARAM_TYPES = [
param(
"TEXT",
dt.string,
marks=pytest.mark.notyet(
["mssql"], reason="Not supported by UTF-8 aware collations"
),
),
param(
"NTEXT",
dt.string,
marks=pytest.mark.notyet(
["mssql"], reason="Not supported by UTF-8 aware collations"
),
),
]
DB_TYPES = RAW_DB_TYPES + PARAM_TYPES


@pytest.mark.parametrize(("server_type", "expected_type"), DB_TYPES, ids=str)
@@ -81,6 +86,40 @@ def test_get_schema(con, server_type, expected_type, temp_table):
assert con.sql(f"SELECT * FROM [{temp_table}]").schema() == expected_schema


def test_schema_type_order(con, temp_table):
columns = []
pairs = {}

quoted = con.compiler.quoted
dialect = con.dialect
table_id = sg.to_identifier(temp_table, quoted=quoted)

for i, (server_type, expected_type) in enumerate(RAW_DB_TYPES):
column_name = f"col_{i}"
columns.append(
sge.ColumnDef(
this=sg.to_identifier(column_name, quoted=quoted), kind=server_type
)
)
pairs[column_name] = expected_type

query = sge.Create(
kind="TABLE", this=sge.Schema(this=table_id, expressions=columns)
)
stmt = query.sql(dialect)

with con.begin() as c:
c.execute(stmt)

expected_schema = ibis.schema(pairs)

assert con.get_schema(temp_table) == expected_schema
assert con.table(temp_table).schema() == expected_schema

raw_sql = sg.select("*").from_(table_id).sql(dialect)
assert con.sql(raw_sql).schema() == expected_schema


def test_builtin_scalar_udf(con):
@udf.scalar.builtin
def difference(a: str, b: str) -> int:

0 comments on commit 3780a13

Please sign in to comment.