Skip to content

Commit

Permalink
fix(duckdb): remove hack to workaround bug that was fixed upstream
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and jcrist committed Sep 12, 2023
1 parent 98b348c commit 310c521
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 36 deletions.
37 changes: 3 additions & 34 deletions ibis/backends/duckdb/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@
from functools import partial
from typing import TYPE_CHECKING, Any

import duckdb
import numpy as np
import sqlalchemy as sa
from packaging.version import parse as vparse
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.functions import GenericFunction
from toolz.curried import flip
Expand All @@ -23,9 +21,6 @@
reduction,
try_cast,
)
from ibis.backends.base.sql.alchemy.registry import (
_translate_case as _base_translate_case,
)
from ibis.backends.postgres.registry import (
_array_index,
_array_slice,
Expand All @@ -44,8 +39,6 @@
for op in operation_registry.keys() - geospatial_functions.keys()
}

_SUPPORTS_MAPS = vparse(duckdb.__version__) >= vparse("0.8.0")


def _round(t, op):
arg, digits = op.args
Expand Down Expand Up @@ -224,26 +217,6 @@ def _struct_column(t, op):
)


def _simple_case(t, op):
return _translate_case(t, op, value=t.translate(op.base))


def _searched_case(t, op):
return _translate_case(t, op, value=None)


def _translate_case(t, op, *, value):
return sa.literal_column(
str(
_base_translate_case(t, op, value=value).compile(
dialect=sa.dialects.registry.load("duckdb")(),
compile_kwargs=dict(literal_binds=True),
)
),
type_=t.get_sqla_type(op.dtype),
)


@compiles(array_map, "duckdb")
def compiles_list_apply(element, compiler, **kw):
*args, signature, result = map(partial(compiler.process, **kw), element.clauses)
Expand Down Expand Up @@ -460,8 +433,6 @@ def _try_cast(t, op):
ops.ArrayStringJoin: fixed_arity(
lambda sep, arr: sa.func.array_aggr(arr, sa.text("'string_agg'"), sep), 2
),
ops.SearchedCase: _searched_case,
ops.SimpleCase: _simple_case,
ops.StartsWith: fixed_arity(sa.func.prefix, 2),
ops.EndsWith: fixed_arity(sa.func.suffix, 2),
ops.Argument: lambda _, op: sa.literal_column(op.name),
Expand All @@ -477,11 +448,9 @@ def _try_cast(t, op):
lambda arg, key: sa.func.array_length(sa.func.element_at(arg, key)) != 0, 2
),
ops.MapLength: unary(sa.func.cardinality),
ops.MapKeys: unary(sa.func.map_keys) if _SUPPORTS_MAPS else _map_keys,
ops.MapValues: unary(sa.func.map_values) if _SUPPORTS_MAPS else _map_values,
ops.MapMerge: (
fixed_arity(sa.func.map_concat, 2) if _SUPPORTS_MAPS else _map_merge
),
ops.MapKeys: unary(sa.func.map_keys),
ops.MapValues: unary(sa.func.map_values),
ops.MapMerge: fixed_arity(sa.func.map_concat, 2),
ops.Hash: unary(sa.func.hash),
ops.Median: reduction(sa.func.median),
ops.First: reduction(sa.func.first),
Expand Down
23 changes: 23 additions & 0 deletions ibis/backends/tests/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -1051,3 +1051,26 @@ def test_levenshtein(con, right):
expr = left.levenshtein(right)
result = con.execute(expr)
assert result == 3


@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
@pytest.mark.notyet(
["mssql"],
reason="doesn't allow boolean expressions in select statements",
raises=sa.exc.OperationalError,
)
@pytest.mark.notyet(["druid"], raises=sa.exc.ProgrammingError)
@pytest.mark.broken(
["oracle"],
reason="sqlalchemy converts True to 1, which cannot be used in CASE WHEN statement",
raises=sa.exc.DatabaseError,
)
@pytest.mark.parametrize(
"expr",
[
param(ibis.case().when(True, "%").end(), id="case"),
param(ibis.ifelse(True, "%", ibis.NA), id="ifelse"),
],
)
def test_no_conditional_percent_escape(con, expr):
assert con.execute(expr) == "%"
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ dask = { version = ">=2022.9.1", optional = true, extras = [
datafusion = { version = ">=0.6,<29", optional = true }
db-dtypes = { version = ">=0.3,<2", optional = true }
deltalake = { version = ">=0.9.0,<1", optional = true }
duckdb = { version = ">=0.3.3,<1", optional = true }
duckdb = { version = ">=0.8.1,<1", optional = true }
duckdb-engine = { version = ">=0.1.8,<1", optional = true }
fsspec = { version = ">=2022.1.0", optional = true }
GeoAlchemy2 = { version = ">=0.6.3,<1,!=0.13.0,!=0.14.0,!=0.14.1", optional = true }
Expand Down

0 comments on commit 310c521

Please sign in to comment.