Skip to content

Commit

Permalink
feat(duckdb): add map operations
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and kszucs committed Apr 10, 2023
1 parent 3a2c4df commit a4c4e77
Show file tree
Hide file tree
Showing 9 changed files with 122 additions and 57 deletions.
5 changes: 5 additions & 0 deletions ci/schema/duckdb.sql
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,8 @@ INSERT INTO win VALUES
('a', 2, 0),
('a', 3, 1),
('a', 4, 1);

-- Table `map`: a single column `kv` of type MAP(STRING, BIGINT).
-- CREATE OR REPLACE drops any existing table of the same name first.
CREATE OR REPLACE TABLE map (kv MAP(STRING, BIGINT));
-- Two rows; each MAP(keys, values) pairs the key list with the value list
-- position by position ('a' -> 1, 'b' -> 2, ...).
INSERT INTO map VALUES
(MAP(['a', 'b', 'c'], [1, 2, 3])),
(MAP(['d', 'e', 'f'], [4, 5, 6]));
29 changes: 1 addition & 28 deletions ibis/backends/duckdb/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,8 @@
from sqlalchemy.ext.compiler import compiles

import ibis.backends.base.sql.alchemy.datatypes as sat
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.base.sql.alchemy import (
AlchemyCompiler,
AlchemyExprTranslator,
to_sqla_type,
)
from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator
from ibis.backends.duckdb.registry import operation_registry


Expand Down Expand Up @@ -45,28 +40,6 @@ def compile_array(element, compiler, **kw):
return f"{compiler.process(element.value_type, **kw)}[]"


try:
import duckdb_engine
except ImportError:
pass
else:

@dt.dtype.register(duckdb_engine.Dialect, sat.UInt64)
@dt.dtype.register(duckdb_engine.Dialect, sat.UInt32)
@dt.dtype.register(duckdb_engine.Dialect, sat.UInt16)
@dt.dtype.register(duckdb_engine.Dialect, sat.UInt8)
def dtype_uint(_, satype, nullable=True):
return getattr(dt, satype.__class__.__name__)(nullable=nullable)

@dt.dtype.register(duckdb_engine.Dialect, sat.ArrayType)
def _(dialect, satype, nullable=True):
return dt.Array(dt.dtype(dialect, satype.value_type), nullable=nullable)

@to_sqla_type.register(duckdb_engine.Dialect, dt.Array)
def _(dialect, itype):
return sat.ArrayType(to_sqla_type(dialect, itype.value_type))


rewrites = DuckDBSQLExprTranslator.rewrites


Expand Down
49 changes: 42 additions & 7 deletions ibis/backends/duckdb/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import parsy
import sqlalchemy as sa
import toolz
from duckdb_engine import Dialect as DuckDBDialect
from sqlalchemy.dialects import postgresql

import ibis.backends.base.sql.alchemy.datatypes as sat
import ibis.expr.datatypes as dt
from ibis.backends.base.sql.alchemy import to_sqla_type
from ibis.common.parsing import (
Expand Down Expand Up @@ -92,11 +92,46 @@ def parse(text: str, default_decimal_parameters=(18, 3)) -> dt.DataType:
return ty.parse(text)


@to_sqla_type.register(DuckDBDialect, dt.UUID)
def sa_duckdb_uuid(*_):
return postgresql.UUID(as_uuid=True)
try:
from duckdb_engine import Dialect as DuckDBDialect
except ImportError:
pass
else:

@dt.dtype.register(DuckDBDialect, sat.UInt64)
@dt.dtype.register(DuckDBDialect, sat.UInt32)
@dt.dtype.register(DuckDBDialect, sat.UInt16)
@dt.dtype.register(DuckDBDialect, sat.UInt8)
def dtype_uint(_, satype, nullable=True):
return getattr(dt, satype.__class__.__name__)(nullable=nullable)

@to_sqla_type.register(DuckDBDialect, (dt.MACADDR, dt.INET))
def sa_duckdb_macaddr(*_):
return sa.TEXT()
@dt.dtype.register(DuckDBDialect, sat.ArrayType)
def _(dialect, satype, nullable=True):
return dt.Array(dt.dtype(dialect, satype.value_type), nullable=nullable)

@dt.dtype.register(DuckDBDialect, sat.MapType)
def _(dialect, satype, nullable=True):
return dt.Map(
dt.dtype(dialect, satype.key_type),
dt.dtype(dialect, satype.value_type),
nullable=nullable,
)

@to_sqla_type.register(DuckDBDialect, dt.UUID)
def sa_duckdb_uuid(*_):
return postgresql.UUID()

@to_sqla_type.register(DuckDBDialect, (dt.MACADDR, dt.INET))
def sa_duckdb_macaddr(*_):
return sa.TEXT()

@to_sqla_type.register(DuckDBDialect, dt.Map)
def sa_duckdb_map(dialect, itype):
return sat.MapType(
to_sqla_type(dialect, itype.key_type),
to_sqla_type(dialect, itype.value_type),
)

@to_sqla_type.register(DuckDBDialect, dt.Array)
def _(dialect, itype):
return sat.ArrayType(to_sqla_type(dialect, itype.value_type))
62 changes: 48 additions & 14 deletions ibis/backends/duckdb/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,17 +136,17 @@ def _literal(t, op):
elif dtype.is_string():
return sa.literal(value)
elif dtype.is_map():
raise NotImplementedError(
f"Ibis dtype `{dtype}` with mapping type "
f"`{type(value).__name__}` isn't yet supported with the duckdb "
"backend"
return sa.func.map(
sa.func.list_value(*value.keys()), sa.func.list_value(*value.values())
)
else:
return sa.cast(sa.literal(value), sqla_type)


if_ = getattr(sa.func, "if")


def _neg_idx_to_pos(array, idx):
    """Build an expression mapping a possibly negative index to a non-negative one.

    When ``idx < 0`` the result is ``array_length(array) + greatest(idx,
    -array_length(array))``, so an index further back than the array is long
    lands on ``0``; otherwise ``idx`` is passed through unchanged.
    """
    # ``if`` is a Python keyword, so the SQL function is looked up by name.
    sql_if = getattr(sa.func, "if")
    length = sa.func.array_length(array)
    clamped = sa.func.greatest(idx, -length)
    return sql_if(idx < 0, length + clamped, idx)

Expand Down Expand Up @@ -259,6 +259,32 @@ def _array_filter(t, op):
)


def _map_keys(t, op):
    """Translate a map-keys operation.

    The translated map argument is serialised with ``to_json``, its keys are
    read back with ``json_keys``, and the result is cast to the SQLAlchemy
    type corresponding to ``op.output_dtype``.
    """
    as_json = sa.func.to_json(t.translate(op.arg))
    keys = sa.func.json_keys(as_json)
    target_type = t.get_sqla_type(op.output_dtype)
    return sa.cast(keys, target_type)


def _map_values(t, op):
    """Translate a map-values operation.

    The translated map argument is serialised with ``to_json``; every key
    reported by ``json_keys`` is then looked up with ``json_extract_string``
    and the result is cast to the SQLAlchemy type for ``op.output_dtype``.
    """
    as_json = sa.func.to_json(t.translate(op.arg))
    keys = sa.func.json_keys(as_json)
    values = sa.func.json_extract_string(as_json, keys)
    return sa.cast(values, t.get_sqla_type(op.output_dtype))


def _map_merge(t, op):
    """Translate a map-merge operation.

    Both operands are serialised with ``to_json`` and combined with
    ``json_merge_patch(left, right)``. A map is then rebuilt from the merged
    document's keys and their ``json_extract_string`` values, and cast to the
    SQLAlchemy type for ``op.output_dtype``.
    """
    lhs_json, rhs_json = (
        sa.func.to_json(t.translate(operand)) for operand in (op.left, op.right)
    )
    merged = sa.func.json_merge_patch(lhs_json, rhs_json)
    merged_keys = sa.func.json_keys(merged)
    merged_values = sa.func.json_extract_string(merged, merged_keys)
    rebuilt = sa.func.map(merged_keys, merged_values)
    return sa.cast(rebuilt, t.get_sqla_type(op.output_dtype))


operation_registry.update(
{
ops.ArrayColumn: (
Expand Down Expand Up @@ -291,7 +317,6 @@ def _array_filter(t, op):
ops.Ln: unary(sa.func.ln),
ops.Log: _log,
ops.IsNan: unary(sa.func.isnan),
# TODO: map operations, but DuckDB's maps are multimaps
ops.Modulus: fixed_arity(operator.mod, 2),
ops.Round: _round,
ops.StructField: (
Expand Down Expand Up @@ -349,6 +374,20 @@ def _array_filter(t, op):
ops.ArrayFilter: _array_filter,
ops.Argument: lambda _, op: sa.literal_column(op.name),
ops.Unnest: unary(sa.func.unnest),
ops.MapGet: fixed_arity(
lambda arg, key, default: sa.func.coalesce(
sa.func.list_extract(sa.func.element_at(arg, key), 1), default
),
3,
),
ops.Map: fixed_arity(sa.func.map, 2),
ops.MapContains: fixed_arity(
lambda arg, key: sa.func.array_length(sa.func.element_at(arg, key)) != 0, 2
),
ops.MapLength: unary(sa.func.cardinality),
ops.MapKeys: _map_keys,
ops.MapValues: _map_values,
ops.MapMerge: _map_merge,
}
)

Expand All @@ -361,14 +400,9 @@ def _array_filter(t, op):
ops.NTile,
# ibis.expr.operations.strings
ops.Translate,
# ibis.expr.operations.maps
ops.MapGet,
ops.MapContains,
ops.MapKeys,
ops.MapValues,
ops.MapMerge,
ops.MapLength,
ops.Map,
# ibis.expr.operations.json
ops.ToJSONMap,
ops.ToJSONArray,
}

operation_registry = {
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/duckdb/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@


class TestConf(BackendTest, RoundAwayFromZero):
supports_map = True

def __init__(self, data_directory: Path, **kwargs: Any) -> None:
self.connection = self.connect(data_directory, **kwargs)

Expand Down
5 changes: 5 additions & 0 deletions ibis/backends/pandas/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,11 @@ def convert_array_to_series(in_dtype, out_dtype, column):
return column.map(lambda x: list(x) if util.is_iterable(x) else x)


@sch.convert.register(np.dtype, dt.Map, pd.Series)
def convert_map_to_series(in_dtype, out_dtype, column):
    """Convert the elements of ``column`` to ``dict`` objects.

    Elements for which ``util.is_iterable`` is true are passed through
    ``dict(...)``; all other elements are returned unchanged. ``in_dtype``
    and ``out_dtype`` are not used by the body.
    """

    def _as_dict(value):
        if util.is_iterable(value):
            return dict(value)
        return value

    return column.map(_as_dict)


@sch.convert.register(np.dtype, dt.JSON, pd.Series)
def convert_json_to_series(in_, out, col: pd.Series):
def try_json(x):
Expand Down
8 changes: 6 additions & 2 deletions ibis/backends/tests/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import pytest
from pytest import param

from ibis.common.exceptions import OperationNotDefinedError

pytestmark = [
pytest.mark.never(["impala"], reason="doesn't support JSON and never will"),
pytest.mark.notyet(["clickhouse"], reason="upstream is broken"),
Expand Down Expand Up @@ -43,12 +45,13 @@ def test_json_getitem(json_t, expr_fn, expected):
tm.assert_series_equal(result, expected)


@pytest.mark.notimpl(["dask", "duckdb", "mysql", "pandas"])
@pytest.mark.notimpl(["dask", "mysql", "pandas"])
@pytest.mark.notyet(["bigquery", "sqlite"], reason="doesn't support maps")
@pytest.mark.notyet(["postgres"], reason="only supports map<string, string>")
@pytest.mark.notyet(
["pyspark", "trino"], reason="should work but doesn't deserialize JSON"
)
@pytest.mark.notimpl(["duckdb"], raises=OperationNotDefinedError)
def test_json_map(json_t):
expr = json_t.js.map.name("res")
result = expr.execute()
Expand All @@ -67,12 +70,13 @@ def test_json_map(json_t):
tm.assert_series_equal(result, expected)


@pytest.mark.notimpl(["dask", "duckdb", "mysql", "pandas"])
@pytest.mark.notimpl(["dask", "mysql", "pandas"])
@pytest.mark.notyet(["sqlite"], reason="doesn't support arrays")
@pytest.mark.notyet(
["pyspark", "trino"], reason="should work but doesn't deserialize JSON"
)
@pytest.mark.notyet(["bigquery"], reason="doesn't allow null in arrays")
@pytest.mark.notimpl(["duckdb"], raises=OperationNotDefinedError)
def test_json_array(json_t):
expr = json_t.js.array.name("res")
result = expr.execute()
Expand Down
17 changes: 12 additions & 5 deletions ibis/backends/tests/test_map.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import contextlib

import numpy as np
import pandas as pd
import pytest
from pytest import param

Expand All @@ -16,7 +17,7 @@
["bigquery", "impala"], reason="Backend doesn't yet implement map types"
),
pytest.mark.notimpl(
["duckdb", "datafusion", "pyspark", "polars", "druid"],
["datafusion", "pyspark", "polars", "druid"],
reason="Not yet implemented in ibis",
),
]
Expand Down Expand Up @@ -49,6 +50,7 @@ def test_literal_map_values(con):

@pytest.mark.notimpl(["trino", "postgres"])
@pytest.mark.notyet(["snowflake"])
@pytest.mark.notyet(["duckdb"], reason="sqlalchemy warning")
def test_scalar_isin_literal_map_keys(con):
mapping = ibis.literal({'a': 1, 'b': 2})
a = ibis.literal('a')
Expand Down Expand Up @@ -92,8 +94,11 @@ def test_map_column_contains_key_scalar(backend, alltypes, df):

@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string")
def test_map_column_contains_key_column(alltypes):
expr = ibis.map(ibis.array([alltypes.string_col]), ibis.array([alltypes.int_col]))
result = expr.contains(alltypes.string_col).name('tmp').execute()
map_expr = ibis.map(
ibis.array([alltypes.string_col]), ibis.array([alltypes.int_col])
)
expr = map_expr.contains(alltypes.string_col).name('tmp')
result = expr.execute()
assert result.all()


Expand Down Expand Up @@ -186,14 +191,16 @@ def test_map_get_with_null_on_not_nullable(con, null_value):
map_type = dt.Map(dt.string, dt.Int16(nullable=False))
value = ibis.literal({'A': 1000, 'B': 2000}).cast(map_type)
expr = value.get('C', null_value)
assert con.execute(expr) is None
result = con.execute(expr)
assert pd.isna(result)


@pytest.mark.parametrize('null_value', [None, ibis.NA])
def test_map_get_with_null_on_null_type_with_null(con, null_value):
value = ibis.literal({'A': None, 'B': None})
expr = value.get('C', null_value)
assert con.execute(expr) is None
result = con.execute(expr)
assert pd.isna(result)


@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string")
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/tests/test_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def test_scalar_param_struct(con):


@pytest.mark.notimpl(
["clickhouse", "datafusion", "duckdb", "impala", "pyspark", "polars", "druid"]
["clickhouse", "datafusion", "impala", "pyspark", "polars", "druid"]
)
@pytest.mark.never(
["mysql", "sqlite", "mssql"],
Expand Down

0 comments on commit a4c4e77

Please sign in to comment.