feat(snowflake): make literal maps and params work
cpcloud committed Jan 18, 2023
1 parent 045edc7 commit dd759d3
Showing 3 changed files with 23 additions and 17 deletions.
18 changes: 15 additions & 3 deletions ibis/backends/snowflake/registry.py
@@ -1,10 +1,11 @@
from __future__ import annotations

import itertools
import json

import numpy as np
import sqlalchemy as sa
from snowflake.sqlalchemy.custom_types import VARIANT
from snowflake.sqlalchemy import VARIANT

import ibis.expr.operations as ops
from ibis.backends.base.sql.alchemy.registry import (
@@ -52,7 +53,7 @@ def _literal(t, op):
        return sa.func.array_construct(*value)
    elif dtype.is_map():
        return sa.func.object_construct_keep_null(
            *zip(itertools.chain.from_iterable(value.items()))
            *itertools.chain.from_iterable(value.items())
        )
    return _postgres_literal(t, op)
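For reference, a standalone sketch (not part of the diff) of why the zip() wrapper is dropped above: OBJECT_CONSTRUCT_KEEP_NULL expects a flat, alternating list of key and value arguments, and zip() over a single iterable wraps each element in a 1-tuple instead.

import itertools

value = {"a": 1, "b": 2}

# New form: flat alternating key/value arguments.
print(list(itertools.chain.from_iterable(value.items())))       # ['a', 1, 'b', 2]

# Old form: zip over one iterable yields 1-tuples, so the SQL function
# received tuple arguments rather than keys and values.
print(list(zip(itertools.chain.from_iterable(value.items()))))  # [('a',), (1,), ('b',), (2,)]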

@@ -112,6 +113,17 @@ def _array_slice(t, op):
    return sa.func.array_slice(t.translate(op.arg), start, stop)


def _map(_, op):
    if not (
        isinstance(keys := op.keys, ops.Literal)
        and isinstance(values := op.values, ops.Literal)
    ):
        raise TypeError("Both keys and values of an `ibis.map` call must be literals")

    obj = dict(zip(keys.value, values.value))
    return sa.func.to_object(sa.func.parse_json(json.dumps(obj, separators=",:")))


_SF_POS_INF = sa.cast(sa.literal("Inf"), sa.FLOAT)
_SF_NEG_INF = -_SF_POS_INF
_SF_NAN = sa.cast(sa.literal("NaN"), sa.FLOAT)
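For reference, a small sketch (not part of the diff) of what the new _map helper builds for a literal map; the SQL in the comment approximates the compiled output rather than quoting it from the commit.

import json

# Hypothetical literal inputs standing in for keys.value and values.value.
keys = ["a", "b"]
values = [1, 2]

obj = dict(zip(keys, values))
payload = json.dumps(obj, separators=",:")
print(payload)  # {"a":1,"b":2}

# sa.func.to_object(sa.func.parse_json(payload)) then renders roughly as
#   TO_OBJECT(PARSE_JSON('{"a":1,"b":2}'))
# with the JSON text passed as a bound parameter.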
@@ -201,7 +213,7 @@ def _array_slice(t, op):
        ops.ArrayColumn: lambda t, op: sa.func.array_construct(
            *map(t.translate, op.cols)
        ),
        ops.ArraySlice: _array_slice,
        ops.Map: _map,
    }
)

8 changes: 6 additions & 2 deletions ibis/backends/tests/test_map.py
@@ -124,12 +124,16 @@ def test_literal_map_get_broadcast(backend, alltypes, df):
    backend.assert_series_equal(result, expected)


@pytest.mark.notyet(["snowflake"])
def test_map_construction(con, alltypes, df):
def test_map_construct_dict(con):
    expr = ibis.map(['a', 'b'], [1, 2])
    result = con.execute(expr.name('tmp'))
    assert result == {'a': 1, 'b': 2}


@pytest.mark.notimpl(
    ["snowflake"], reason="unclear how to implement two arrays -> object construction"
)
def test_map_construct_array_column(con, alltypes, df):
    expr = ibis.map(ibis.array([alltypes.string_col]), ibis.array([alltypes.int_col]))
    result = con.execute(expr)
    expected = df.apply(lambda row: {row['string_col']: row['int_col']}, axis=1)
14 changes: 2 additions & 12 deletions ibis/backends/tests/test_param.py
@@ -58,7 +58,7 @@ def test_timestamp_accepts_date_literals(alltypes):
    assert expr.compile(params=params) is not None


@pytest.mark.notimpl(["dask", "datafusion", "impala", "pandas", "pyspark", "snowflake"])
@pytest.mark.notimpl(["dask", "datafusion", "impala", "pandas", "pyspark"])
@pytest.mark.never(
    ["mysql", "sqlite", "mssql"], reason="backend will never implement array types"
)
@@ -84,17 +84,7 @@ def test_scalar_param_struct(con):


@pytest.mark.notimpl(
    [
        "clickhouse",
        "datafusion",
        # TODO: duckdb maps are tricky because they are multimaps
        "duckdb",
        "impala",
        "pyspark",
        "snowflake",
        "polars",
        "trino",
    ]
    ["clickhouse", "datafusion", "duckdb", "impala", "pyspark", "polars", "trino"]
)
@pytest.mark.never(
    ["mysql", "sqlite", "mssql"],
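For reference, a hedged sketch (not taken from the commit) of the scalar map parameter pattern that the un-skipped test exercises; the parameter type, keys, and values here are illustrative.

import ibis

# A map-typed scalar parameter, spelled with the dtype string syntax.
param = ibis.param("map<string, float64>")
expr = param["a"] + param["b"]

# Bind a concrete value at compile time, as the surrounding tests do; with
# this commit the Snowflake backend can compile the map parameter as well.
sql = expr.compile(params={param: {"a": 1.0, "b": 3.0}})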
