Skip to content

Commit

Permalink
feat(mssql): use odbc
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Oct 10, 2023
1 parent 90b4bf7 commit f09ba0f
Show file tree
Hide file tree
Showing 22 changed files with 253 additions and 177 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/check-generated-files.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
run: sudo apt-get update -y -q

- name: install system dependencies
run: sudo apt-get install -y -q build-essential graphviz libgeos-dev libkrb5-dev krb5-config freetds-dev
run: sudo apt-get install -y -q build-essential graphviz libgeos-dev freetds-dev unixodbc-dev

- name: install poetry
run: pip install 'poetry==1.6.1'
Expand Down
49 changes: 38 additions & 11 deletions .github/workflows/ibis-backends.yml
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ jobs:
- gen_lockfile_backends
env:
SQLALCHEMY_WARN_20: "1"
ODBCSYSINI: "${{ github.workspace }}/.odbc"
strategy:
fail-fast: false
matrix:
Expand Down Expand Up @@ -172,15 +173,14 @@ jobs:
- ninja-build
- name: mssql
title: MS SQL Server
serial: true
extras:
- mssql
services:
- mssql
sys-deps:
- libkrb5-dev
- krb5-config
- freetds-dev
- unixodbc-dev
- tdsodbc
- name: trino
title: Trino
extras:
Expand Down Expand Up @@ -271,15 +271,14 @@ jobs:
backend:
name: mssql
title: MS SQL Server
serial: true
extras:
- mssql
services:
- mssql
sys-deps:
- libkrb5-dev
- krb5-config
- freetds-dev
- unixodbc-dev
- tdsodbc
- os: windows-latest
backend:
name: trino
Expand Down Expand Up @@ -333,6 +332,16 @@ jobs:
- name: checkout
uses: actions/checkout@v4

- name: setup odbc for mssql
if: ${{ matrix.backend.name == 'mssql' }}
run: |
mkdir -p "$ODBCSYSINI"
{
echo '[FreeTDS]'
echo "Driver = libtdsodbc.so"
} > "$ODBCSYSINI/odbcinst.ini"
- uses: extractions/setup-just@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
Expand Down Expand Up @@ -699,6 +708,10 @@ jobs:
- mssql
extras:
- mssql
sys-deps:
- freetds-dev
- unixodbc-dev
- tdsodbc
- name: mysql
title: MySQL
services:
Expand All @@ -713,6 +726,8 @@ jobs:
extras:
- geospatial
- postgres
sys-deps:
- libgeos-dev
- name: sqlite
title: SQLite
extras:
Expand All @@ -735,17 +750,29 @@ jobs:
- oracle
services:
- oracle
env:
ODBCSYSINI: "${{ github.workspace }}/.odbc"
steps:
- name: checkout
uses: actions/checkout@v4

- name: install libgeos for shapely
if: ${{ matrix.backend.name == 'postgres' }}
run: sudo apt-get install -qq -y build-essential libgeos-dev
- name: update and install system dependencies
if: matrix.backend.sys-deps != null
run: |
set -euo pipefail
- name: install freetds-dev for mssql
sudo apt-get update -qq -y
sudo apt-get install -qq -y build-essential ${{ join(matrix.backend.sys-deps, ' ') }}
- name: setup odbc for mssql
if: ${{ matrix.backend.name == 'mssql' }}
run: sudo apt-get install -qq -y build-essential libkrb5-dev krb5-config freetds-dev
run: |
mkdir -p "$ODBCSYSINI"
{
echo '[FreeTDS]'
echo "Driver = libtdsodbc.so"
} > "$ODBCSYSINI/odbcinst.ini"
- uses: extractions/setup-just@v1
env:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ibis-docs-lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ jobs:
python-version: "3.11"

- name: install system dependencies
run: sudo apt-get install -qq -y build-essential libgeos-dev freetds-dev libkrb5-dev krb5-config
run: sudo apt-get install -qq -y build-essential libgeos-dev freetds-dev unixodbc-dev

- uses: syphar/restore-virtualenv@v1
with:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ibis-main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ jobs:
set -euo pipefail
sudo apt-get update -y -q
sudo apt-get install -y -q build-essential graphviz libgeos-dev libkrb5-dev freetds-dev
sudo apt-get install -y -q build-essential graphviz libgeos-dev freetds-dev unixodbc-dev
- name: checkout
uses: actions/checkout@v4
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,4 @@ ibis/examples/descriptions

# chat
*zuliprc*
.odbc
16 changes: 16 additions & 0 deletions flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,16 @@
overlays = [ self.overlays.default ];
};

odbcInstIni = pkgs.writeTextDir "odbcinst.ini" ''
[FreeTDS]
Driver = ${pkgs.lib.makeLibraryPath [ pkgs.freetds ]}/libtdsodbc.so
'';

odbcSysIni = pkgs.symlinkJoin {
name = "odbc-config";
paths = [ odbcInstIni ];
};

backendDevDeps = with pkgs; [
# impala UDFs
clang_15
Expand All @@ -54,6 +64,9 @@
duckdb
# mysql
mycli
# pyodbc setup debugging
# in particular: odbcinst -j
unixODBC
# pyspark
openjdk17_headless
# postgres client
Expand Down Expand Up @@ -116,6 +129,9 @@
MSSQL_SA_PASSWORD = "1bis_Testing!";
DRUID_URL = "druid://localhost:8082/druid/v2/sql";

# needed for mssql+pyodbc
ODBCSYSINI = "${odbcSysIni}";

__darwinAllowLocalNetworking = true;
};
in
Expand Down
22 changes: 17 additions & 5 deletions ibis/backends/base/sql/alchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import abc
import atexit
import contextlib
import getpass
import warnings
from operator import methodcaller
from typing import TYPE_CHECKING, Any
Expand Down Expand Up @@ -134,18 +133,28 @@ def _compile_type(self, dtype) -> str:
self.compiler.translator_class.get_sqla_type(dtype)
).compile(dialect=dialect)

def _build_alchemy_url(self, url, host, port, user, password, database, driver):
def _build_alchemy_url(
self,
url: str | None,
host: str | None,
port: int | None,
user: str | None,
password: str | None,
database: str | None,
driver: str | None,
query: Mapping[str, Any] | None = None,
) -> sa.engine.URL:
if url is not None:
return sa.engine.url.make_url(url)

user = user or getpass.getuser()
return sa.engine.url.URL.create(
driver,
host=host,
port=port,
username=user,
password=password,
database=database,
query=query or {},
)

@property
Expand Down Expand Up @@ -817,8 +826,11 @@ def _get_compiled_statement(
compiled = definition.compile(
dialect=self.con.dialect, compile_kwargs=compile_kwargs
)
lines = self._get_temp_view_definition(name, definition=compiled)
return lines, compiled.params
create_view = self._get_temp_view_definition(name, definition=compiled)
params = compiled.params
if compiled.positional:
params = tuple(params.values())
return create_view, params

def _create_temp_view(self, view: sa.Table, definition: sa.sql.Selectable) -> None:
raw_name = view.name
Expand Down
37 changes: 31 additions & 6 deletions ibis/backends/mssql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Literal
import contextlib
from typing import TYPE_CHECKING, Any

import sqlalchemy as sa
import toolz
Expand Down Expand Up @@ -34,21 +35,30 @@ def do_connect(
port: int = 1433,
database: str | None = None,
url: str | None = None,
driver: Literal["pymssql"] = "pymssql",
query: Mapping[str, Any] | None = None,
driver: str | None = None,
**kwargs: Any,
) -> None:
if driver != "pymssql":
raise NotImplementedError("pymssql is currently the only supported driver")
if query is None:
query = {}

if driver is not None:
query["driver"] = driver

alchemy_url = self._build_alchemy_url(
url=url,
host=host,
port=port,
user=user,
password=password,
database=database,
driver=f"mssql+{driver}",
driver="mssql+pyodbc",
query=query,
)

engine = sa.create_engine(alchemy_url, poolclass=sa.pool.StaticPool)
engine = sa.create_engine(
alchemy_url, poolclass=sa.pool.StaticPool, connect_args=kwargs
)

@sa.event.listens_for(engine, "connect")
def connect(dbapi_connection, connection_record):
Expand Down Expand Up @@ -84,6 +94,21 @@ def list_databases(self, like: str | None = None) -> list[str]:
def current_schema(self) -> str:
return self._scalar_query(sa.select(sa.func.schema_name()))

@contextlib.contextmanager
def _safe_raw_sql(self, stmt, *args, **kwargs):
sql = str(
stmt.compile(
dialect=self.con.dialect, compile_kwargs={"literal_binds": True}
)
)
with self.begin() as con:
yield con.exec_driver_sql(sql, *args, **kwargs)

def _get_compiled_statement(self, view: sa.Table, definition: sa.sql.Selectable):
return super()._get_compiled_statement(
view, definition, compile_kwargs={"literal_binds": True}
)

def _get_temp_view_definition(
self, name: str, definition: sa.sql.compiler.Compiled
) -> str:
Expand Down
40 changes: 39 additions & 1 deletion ibis/backends/mssql/registry.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import sqlalchemy as sa
from sqlalchemy.dialects import mssql
from sqlalchemy.ext.compiler import compiles

import ibis.common.exceptions as com
Expand All @@ -12,6 +13,9 @@
sqlalchemy_window_functions_registry,
unary,
)
from ibis.backends.base.sql.alchemy.registry import (
_literal as _alchemy_literal,
)
from ibis.backends.base.sql.alchemy.registry import substr, variance_reduction


Expand Down Expand Up @@ -112,6 +116,39 @@ def _temporal_delta(t, op):
return sa.func.datediff(sa.literal_column(op.part.value.upper()), right, left)


def _literal(t, op):
dtype = op.dtype
value = op.value

if value is not None:
if dtype.is_timestamp():
return sa.func.datetime2fromparts(
value.year,
value.month,
value.day,
value.hour,
value.minute,
value.second,
value.microsecond,
6,
)
elif dtype.is_date():
return sa.func.datefromparts(value.year, value.month, value.day)
elif dtype.is_time():
return sa.func.timefromparts(
value.hour,
value.minute,
value.second,
value.microsecond,
sa.literal_column("6"),
)
elif dtype.is_uuid():
return sa.cast(sa.literal(str(value)), mssql.UNIQUEIDENTIFIER)
elif dtype.is_binary():
return sa.cast(sa.literal(value), mssql.VARBINARY)
return _alchemy_literal(t, op)


operation_registry = sqlalchemy_operation_registry.copy()
operation_registry.update(sqlalchemy_window_functions_registry)

Expand Down Expand Up @@ -195,7 +232,7 @@ def _temporal_delta(t, op):
6,
),
ops.TimeFromHMS: fixed_arity(
lambda h, m, s: sa.func.timefromparts(h, m, s, 0, 0), 3
lambda h, m, s: sa.func.timefromparts(h, m, s, 0, sa.literal_column("0")), 3
),
ops.TimestampTruncate: _timestamp_truncate,
ops.DateTruncate: _timestamp_truncate,
Expand All @@ -206,6 +243,7 @@ def _temporal_delta(t, op):
ops.TimeDelta: _temporal_delta,
ops.DateDelta: _temporal_delta,
ops.TimestampDelta: _temporal_delta,
ops.Literal: _literal,
}
)

Expand Down
Loading

0 comments on commit f09ba0f

Please sign in to comment.