Skip to content

Commit

Permalink
feat(mssql): use odbc
Browse files: browse the repository at this point in the history
  • Loading branch information
cpcloud authored and gforsyth committed Dec 19, 2023
1 parent ca6c2a5 commit f03ad0c
Show file tree
Hide file tree
Showing 17 changed files with 135 additions and 163 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/check-generated-files.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
run: sudo apt-get update -y -q

- name: install system dependencies
run: sudo apt-get install -y -q build-essential graphviz libgeos-dev libkrb5-dev krb5-config freetds-dev
run: sudo apt-get install -y -q build-essential graphviz libgeos-dev freetds-dev unixodbc-dev

- name: install poetry
run: pip install 'poetry==1.7.1'
Expand Down
49 changes: 36 additions & 13 deletions .github/workflows/ibis-backends.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ jobs:
runs-on: ${{ matrix.os }}
env:
SQLALCHEMY_WARN_20: "1"
ODBCSYSINI: "${{ github.workspace }}/.odbc"
strategy:
fail-fast: false
matrix:
Expand Down Expand Up @@ -130,15 +131,14 @@ jobs:
- ninja-build
- name: mssql
title: MS SQL Server
serial: true
extras:
- mssql
services:
- mssql
sys-deps:
- libkrb5-dev
- krb5-config
- freetds-dev
- unixodbc-dev
- tdsodbc
- name: trino
title: Trino
extras:
Expand Down Expand Up @@ -237,15 +237,14 @@ jobs:
backend:
name: mssql
title: MS SQL Server
serial: true
extras:
- mssql
services:
- mssql
sys-deps:
- libkrb5-dev
- krb5-config
- freetds-dev
- unixodbc-dev
- tdsodbc
- os: windows-latest
backend:
name: trino
Expand Down Expand Up @@ -315,6 +314,16 @@ jobs:
- name: checkout
uses: actions/checkout@v4

- name: setup odbc for mssql
if: ${{ matrix.backend.name == 'mssql' }}
run: |
mkdir -p "$ODBCSYSINI"
{
echo '[FreeTDS]'
echo "Driver = libtdsodbc.so"
} > "$ODBCSYSINI/odbcinst.ini"
- uses: extractions/setup-just@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
Expand Down Expand Up @@ -664,6 +673,10 @@ jobs:
- mssql
extras:
- mssql
sys-deps:
- freetds-dev
- unixodbc-dev
- tdsodbc
- name: mysql
title: MySQL
services:
Expand All @@ -678,6 +691,8 @@ jobs:
extras:
- geospatial
- postgres
sys-deps:
- libgeos-dev
- name: sqlite
title: SQLite
extras:
Expand All @@ -700,21 +715,29 @@ jobs:
- oracle
services:
- oracle
env:
ODBCSYSINI: "${{ github.workspace }}/.odbc"
steps:
- name: checkout
uses: actions/checkout@v4

- name: install libgeos for shapely
if: ${{ matrix.backend.name == 'postgres' }}
- name: update and install system dependencies
if: matrix.backend.sys-deps != null
run: |
sudo apt-get update -y -qq
sudo apt-get install -qq -y build-essential libgeos-dev
set -euo pipefail
- name: install freetds-dev for mssql
sudo apt-get update -qq -y
sudo apt-get install -qq -y build-essential ${{ join(matrix.backend.sys-deps, ' ') }}
- name: setup odbc for mssql
if: ${{ matrix.backend.name == 'mssql' }}
run: |
sudo apt-get update -y -qq
sudo apt-get install -qq -y build-essential libkrb5-dev krb5-config freetds-dev
mkdir -p "$ODBCSYSINI"
{
echo '[FreeTDS]'
echo "Driver = libtdsodbc.so"
} > "$ODBCSYSINI/odbcinst.ini"
- uses: extractions/setup-just@v1
env:
Expand Down
4 changes: 1 addition & 3 deletions .github/workflows/ibis-docs-lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,7 @@ jobs:
python-version: "3.11"

- name: install system dependencies
run: |
sudo apt-get update -y -qq
sudo apt-get install -qq -y build-essential libgeos-dev freetds-dev libkrb5-dev krb5-config
run: sudo apt-get install -qq -y build-essential libgeos-dev freetds-dev unixodbc-dev

- uses: syphar/restore-virtualenv@v1
with:
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/ibis-main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,8 @@ jobs:
run: |
set -euo pipefail
sudo apt-get update -y -qq
sudo apt-get install -y -q build-essential graphviz libgeos-dev libkrb5-dev freetds-dev
sudo apt-get update -y -q
sudo apt-get install -y -q build-essential graphviz libgeos-dev freetds-dev unixodbc-dev
- name: checkout
uses: actions/checkout@v4
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,4 @@ ibis/examples/descriptions

# chat
*zuliprc*
.odbc
9 changes: 9 additions & 0 deletions flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@
duckdb
# mysql
mariadb-client
# pyodbc setup debugging
# in particular: odbcinst -j
unixODBC
# pyspark
openjdk17_headless
# postgres client
Expand Down Expand Up @@ -111,6 +114,12 @@
MSSQL_SA_PASSWORD = "1bis_Testing!";
DRUID_URL = "druid://localhost:8082/druid/v2/sql";

# needed for mssql+pyodbc
ODBCSYSINI = pkgs.writeTextDir "odbcinst.ini" ''
[FreeTDS]
Driver = ${pkgs.lib.makeLibraryPath [ pkgs.freetds ]}/libtdsodbc.so
'';

__darwinAllowLocalNetworking = true;
};
in
Expand Down
22 changes: 17 additions & 5 deletions ibis/backends/base/sql/alchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import abc
import atexit
import contextlib
import getpass
import warnings
from operator import methodcaller
from typing import TYPE_CHECKING, Any
Expand Down Expand Up @@ -136,18 +135,28 @@ def _compile_type(self, dtype) -> str:
self.compiler.translator_class.get_sqla_type(dtype)
).compile(dialect=dialect)

def _build_alchemy_url(self, url, host, port, user, password, database, driver):
def _build_alchemy_url(
self,
url: str | None,
host: str | None,
port: int | None,
user: str | None,
password: str | None,
database: str | None,
driver: str | None,
query: Mapping[str, Any] | None = None,
) -> sa.engine.URL:
if url is not None:
return sa.engine.url.make_url(url)

user = user or getpass.getuser()
return sa.engine.url.URL.create(
driver,
host=host,
port=port,
username=user,
password=password,
database=database,
query=query or {},
)

@property
Expand Down Expand Up @@ -875,8 +884,11 @@ def _get_compiled_statement(
compiled = definition.compile(
dialect=self.con.dialect, compile_kwargs=compile_kwargs
)
lines = self._get_temp_view_definition(name, definition=compiled)
return lines, compiled.params
create_view = self._get_temp_view_definition(name, definition=compiled)
params = compiled.params
if compiled.positional:
params = tuple(params.values())
return create_view, params

def _create_temp_view(self, view: sa.Table, definition: sa.sql.Selectable) -> None:
raw_name = view.name
Expand Down
65 changes: 27 additions & 38 deletions ibis/backends/mssql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Literal
import contextlib
from typing import TYPE_CHECKING, Any

import sqlalchemy as sa
import sqlglot as sg
import toolz

from ibis.backends.base import CanCreateDatabase
Expand Down Expand Up @@ -35,21 +35,30 @@ def do_connect(
port: int = 1433,
database: str | None = None,
url: str | None = None,
driver: Literal["pymssql"] = "pymssql",
query: Mapping[str, Any] | None = None,
driver: str | None = None,
**kwargs: Any,
) -> None:
if driver != "pymssql":
raise NotImplementedError("pymssql is currently the only supported driver")
if query is None:
query = {}

if driver is not None:
query["driver"] = driver

alchemy_url = self._build_alchemy_url(
url=url,
host=host,
port=port,
user=user,
password=password,
database=database,
driver=f"mssql+{driver}",
driver="mssql+pyodbc",
query=query,
)

engine = sa.create_engine(alchemy_url, poolclass=sa.pool.StaticPool)
engine = sa.create_engine(
alchemy_url, poolclass=sa.pool.StaticPool, connect_args=kwargs
)

@sa.event.listens_for(engine, "connect")
def connect(dbapi_connection, connection_record):
Expand Down Expand Up @@ -85,40 +94,20 @@ def list_databases(self, like: str | None = None) -> list[str]:
def current_schema(self) -> str:
return self._scalar_query(sa.select(sa.func.schema_name()))

def list_tables(
self,
like: str | None = None,
database: str | None = None,
schema: str | None = None,
) -> list[str]:
tablequery = sg.select("name").from_(
sg.table("tables", db="sys", catalog=database)
@contextlib.contextmanager
def _safe_raw_sql(self, stmt, *args, **kwargs):
sql = str(
stmt.compile(
dialect=self.con.dialect, compile_kwargs={"literal_binds": True}
)
)
viewquery = sg.select("name").from_(
sg.table("views", db="sys", catalog=database)
)

if schema is not None:
table_predicate = sg.func(
"schema_name",
sg.column("schema_id", table="tables", db="sys", catalog=database),
).eq(schema)
view_predicate = sg.func(
"schema_name",
sg.column("schema_id", table="views", db="sys", catalog=database),
).eq(schema)
tablequery = tablequery.where(table_predicate)
viewquery = viewquery.where(view_predicate)

tablequery = sa.text(tablequery.sql(dialect="tsql"))
viewquery = sa.text(viewquery.sql(dialect="tsql"))

with self.begin() as con:
tablequery = list(con.execute(tablequery).scalars())
viewresults = list(con.execute(viewquery).scalars())
results = tablequery + viewresults
yield con.exec_driver_sql(sql, *args, **kwargs)

return self._filter_with_like(results, like)
def _get_compiled_statement(self, view: sa.Table, definition: sa.sql.Selectable):
return super()._get_compiled_statement(
view, definition, compile_kwargs={"literal_binds": True}
)

def _get_temp_view_definition(
self, name: str, definition: sa.sql.compiler.Compiled
Expand Down
12 changes: 8 additions & 4 deletions ibis/backends/mssql/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
MSSQL_HOST = os.environ.get("IBIS_TEST_MSSQL_HOST", "localhost")
MSSQL_PORT = int(os.environ.get("IBIS_TEST_MSSQL_PORT", 1433))
IBIS_TEST_MSSQL_DB = os.environ.get("IBIS_TEST_MSSQL_DATABASE", "ibis_testing")
MSSQL_PYODBC_DRIVER = os.environ.get("IBIS_TEST_MSSQL_PYODBC_DRIVER", "FreeTDS")


class TestConf(ServiceBackendTest):
Expand All @@ -32,7 +33,7 @@ class TestConf(ServiceBackendTest):
supports_json = False
rounding_method = "half_to_even"
service_name = "mssql"
deps = "pymssql", "sqlalchemy"
deps = "pyodbc", "sqlalchemy"

@property
def test_files(self) -> Iterable[Path]:
Expand All @@ -57,10 +58,12 @@ def _load_data(
script_dir
Location of scripts defining schemas
"""
params = f"driver={MSSQL_PYODBC_DRIVER}"
url = sa.engine.make_url(
f"mssql+pyodbc://{user}:{password}@{host}:{port:d}/{database}?{params}"
)
init_database(
url=sa.engine.make_url(
f"mssql+pymssql://{user}:{password}@{host}:{port:d}/{database}"
),
url=url,
database=database,
schema=self.ddl_script,
isolation_level="AUTOCOMMIT",
Expand All @@ -75,6 +78,7 @@ def connect(*, tmpdir, worker_id, **kw):
password=MSSQL_PASS,
database=IBIS_TEST_MSSQL_DB,
port=MSSQL_PORT,
driver=MSSQL_PYODBC_DRIVER,
**kw,
)

Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/mssql/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def count_big(x, where: bool = True) -> int:
ft = con.tables.functional_alltypes
expr = count_big(ft.id)
with pytest.raises(
sa.exc.OperationalError, match="An expression of non-boolean type specified"
sa.exc.ProgrammingError, match="An expression of non-boolean type specified"
):
assert expr.execute()

Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,7 @@ def test_list_databases(alchemy_con):
@pytest.mark.never(
["bigquery", "postgres", "mssql", "mysql", "snowflake", "oracle"],
reason="backend does not support client-side in-memory tables",
raises=(sa.exc.OperationalError, TypeError),
raises=(sa.exc.OperationalError, TypeError, sa.exc.InterfaceError),
)
@pytest.mark.notyet(
["trino"], reason="memory connector doesn't allow writing to tables"
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,7 @@ def test_select_filter_select(backend, alltypes, df):


@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
@pytest.mark.broken(["mssql"], raises=sa.exc.OperationalError)
@pytest.mark.broken(["mssql"], raises=sa.exc.ProgrammingError)
def test_between(backend, alltypes, df):
expr = alltypes.double_col.between(5, 10)
result = expr.execute().rename("double_col")
Expand Down
Loading

0 comments on commit f03ad0c

Please sign in to comment.