Skip to content

Commit

Permalink
feat(sqlalchemy): generalize unnest to work on backends that don't su…
Browse files Browse the repository at this point in the history
…pport it
  • Loading branch information
cpcloud committed Feb 1, 2023
1 parent 80b05d8 commit 5943ce7
Show file tree
Hide file tree
Showing 8 changed files with 108 additions and 28 deletions.
37 changes: 29 additions & 8 deletions ibis/backends/base/sql/alchemy/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import functools

import sqlalchemy as sa
import toolz
from sqlalchemy import sql

import ibis.expr.analysis as an
import ibis.expr.operations as ops
from ibis.backends.base.sql.alchemy.database import AlchemyTable
from ibis.backends.base.sql.alchemy.translator import (
Expand Down Expand Up @@ -207,20 +209,21 @@ def _compile_table_set(self):
def _add_select(self, table_set):
to_select = []

context = self.context
select_set = self.select_set

has_select_star = False
for op in self.select_set:
for op in select_set:
if isinstance(op, ops.Value):
arg = self._translate(op, named=True)
elif isinstance(op, ops.TableNode):
arg = context.get_ref(op)
if op.equals(self.table_set):
cached_table = self.context.get_ref(op)
if cached_table is None:
has_select_star = True
if has_select_star := arg is None:
continue
else:
arg = table_set
else:
arg = self.context.get_ref(op)
if arg is None:
raise ValueError(op)
else:
Expand All @@ -242,11 +245,29 @@ def _add_select(self, table_set):
if self.distinct:
result = result.distinct()

# if we're SELECT *-ing or there's no table_set (e.g., SELECT 1) then
# we can return early
if has_select_star or table_set is None:
# only process unnest if the backend doesn't support SELECT UNNEST(...)
unnest_children = []
if not self.translator_class.supports_unnest_in_select:
unnest_children.extend(
map(
context.get_ref,
toolz.unique(an.find_toplevel_unnest_children(select_set)),
)
)

# if we're SELECT *-ing or there's no table_set (e.g., SELECT 1) *and*
# there are no unnest operations then we can return early
if (has_select_star or table_set is None) and not unnest_children:
return result

if unnest_children:
# get all the unnests plus the current froms of the result selection
# and build up the cross join
table_set = functools.reduce(
functools.partial(sa.sql.FromClause.join, onclause=True),
toolz.unique(toolz.concatv(unnest_children, result.get_final_froms())),
)

return result.select_from(table_set)

def _add_group_by(self, fragment):
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/base/sql/alchemy/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ class AlchemyExprTranslator(ExprTranslator):

_dialect_name = "default"

supports_unnest_in_select = True

@functools.cached_property
def dialect(self) -> sa.engine.interfaces.Dialect:
if (name := self._dialect_name) == "default":
Expand Down
1 change: 1 addition & 0 deletions ibis/backends/snowflake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class SnowflakeExprTranslator(AlchemyExprTranslator):
)
_require_order_by = (*AlchemyExprTranslator._require_order_by, ops.Reduction)
_dialect_name = "snowflake"
supports_unnest_in_select = False


class SnowflakeCompiler(AlchemyCompiler):
Expand Down
33 changes: 32 additions & 1 deletion ibis/backends/snowflake/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,31 @@ def _arbitrary(t, op):
return t._reduction(sa.func.min, op)


class flatten(GenericFunction):
    """SQLAlchemy generic function wrapping Snowflake's ``FLATTEN``.

    The function is declared table-valued with a single ``value`` column
    of the given element *type*, so callers can ``.lateral()`` it and
    select ``.c["value"]``.
    """

    def __init__(self, arg, *, type: sa.types.TypeEngine) -> None:
        super().__init__(arg)
        # FLATTEN produces a table; expose only the "value" column typed
        # as the array's element type.
        self.type = sqltypes.TableValueType(sa.Column("value", type))


@compiles(flatten, "snowflake")
def compiles_flatten(element, compiler, **kw):
    """Render a :class:`flatten` call as Snowflake's FLATTEN in ARRAY mode."""
    rendered_args = compiler.function_argspec(element, **kw)
    return f"FLATTEN(INPUT => {rendered_args}, MODE => 'ARRAY')"


def _unnest(t, op):
    """Translate ``ops.Unnest`` for Snowflake.

    Snowflake's FLATTEN drops NULL array elements, so the array is first
    serialized to a string with a unique separator and split back apart;
    the empty pieces (former NULLs) are then mapped back to NULL.

    HACK: https://community.snowflake.com/s/question/0D50Z000086MVhnSAG/has-anyone-found-a-way-to-unnest-an-array-without-loosing-the-null-values
    """
    translated = t.translate(op.arg)
    delim = util.guid()
    element_type = t.get_sqla_type(op.output_dtype)
    # round-trip the array through a string so NULL elements survive as ""
    joined = sa.func.array_to_string(translated, delim)
    pieces = sa.func.split(joined, delim)
    value = flatten(pieces, type=element_type).lateral().c["value"]
    # "" marks an element that was NULL before the round-trip
    return sa.cast(sa.func.nullif(value, ""), type_=element_type)


_TIMESTAMP_UNITS_TO_SCALE = {"s": 0, "ms": 3, "us": 6, "ns": 9}

_SF_POS_INF = sa.func.to_double("Inf")
Expand Down Expand Up @@ -259,6 +284,13 @@ def _arbitrary(t, op):
ops.StructColumn: lambda t, op: sa.func.object_construct_keep_null(
*itertools.chain.from_iterable(zip(op.names, map(t.translate, op.values)))
),
ops.Unnest: unary(
lambda arg: (
sa.func.table(sa.func.flatten(arg))
.table_valued("value")
.columns["value"]
)
),
}
)

Expand All @@ -270,7 +302,6 @@ def _arbitrary(t, op):
ops.NTile,
# ibis.expr.operations.array
ops.ArrayRepeat,
ops.Unnest,
# ibis.expr.operations.reductions
ops.MultiQuantile,
# ibis.expr.operations.strings
Expand Down
43 changes: 25 additions & 18 deletions ibis/backends/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,6 @@ def test_array_index(con, idx):
["mysql", "sqlite"],
reason="array types are unsupported",
),
pytest.mark.never(
["snowflake"],
reason="snowflake has an extremely specialized way of implementing arrays",
),
# someone just needs to implement these
pytest.mark.notimpl(["datafusion", "dask"]),
duckdb_0_4_0,
Expand All @@ -142,8 +138,7 @@ def test_array_index(con, idx):
builtin_array,
pytest.mark.notimpl(["pandas"]),
pytest.mark.notyet(
["bigquery", "snowflake", "trino"],
reason="doesn't support unnest in SELECT position",
["bigquery"], reason="doesn't support unnest in SELECT position"
),
)

Expand All @@ -153,6 +148,10 @@ def test_array_index(con, idx):
["clickhouse", "duckdb", "pandas", "pyspark", "snowflake", "polars"],
reason="backend does not flatten array types",
)
@pytest.mark.never(
["snowflake"],
reason="snowflake has an extremely specialized way of implementing arrays",
)
@pytest.mark.never(["bigquery"], reason="doesn't support arrays of arrays")
def test_array_discovery_postgres(con):
t = con.table("array_types")
Expand All @@ -170,6 +169,10 @@ def test_array_discovery_postgres(con):


@builtin_array
@pytest.mark.never(
["snowflake"],
reason="snowflake has an extremely specialized way of implementing arrays",
)
@pytest.mark.never(
["duckdb", "pandas", "postgres", "pyspark", "snowflake", "polars", "trino"],
reason="backend supports nullable nested types",
Expand Down Expand Up @@ -202,6 +205,10 @@ def test_array_discovery_clickhouse(con):
reason="trino supports nested arrays, but not with the postgres connector",
)
@pytest.mark.never(["bigquery"], reason="doesn't support arrays of arrays")
@pytest.mark.never(
["snowflake"],
reason="snowflake has an extremely specialized way of implementing arrays",
)
def test_array_discovery_desired(con):
t = con.table("array_types")
expected = ibis.schema(
Expand Down Expand Up @@ -259,7 +266,7 @@ def test_unnest_simple(con):
.astype("float64")
.rename("tmp")
)
expr = array_types.x.unnest()
expr = array_types.x.cast("!array<float64>").unnest()
result = expr.execute().rename("tmp")
tm.assert_series_equal(result, expected)

Expand Down Expand Up @@ -291,26 +298,24 @@ def test_unnest_complex(con):


@unnest
@pytest.mark.never(
"pyspark",
reason="pyspark throws away nulls in collect_list",
)
@pytest.mark.never(
"clickhouse",
reason="clickhouse throws away nulls in groupArray",
)
@pytest.mark.never("pyspark", reason="pyspark throws away nulls in collect_list")
@pytest.mark.never("clickhouse", reason="clickhouse throws away nulls in groupArray")
@pytest.mark.notimpl("polars")
def test_unnest_idempotent(con):
array_types = con.table("array_types")
df = array_types.execute()
expr = (
array_types.select(["scalar_column", array_types.x.unnest().name("x")])
array_types.select(
["scalar_column", array_types.x.cast("!array<int64>").unnest().name("x")]
)
.group_by("scalar_column")
.aggregate(x=lambda t: t.x.collect())
.order_by("scalar_column")
)
result = expr.execute()
expected = df[["scalar_column", "x"]]
expected = (
df[["scalar_column", "x"]].sort_values("scalar_column").reset_index(drop=True)
)
tm.assert_frame_equal(result, expected)


Expand All @@ -320,7 +325,9 @@ def test_unnest_no_nulls(con):
array_types = con.table("array_types")
df = array_types.execute()
expr = (
array_types.select(["scalar_column", array_types.x.unnest().name("y")])
array_types.select(
["scalar_column", array_types.x.cast("!array<int64>").unnest().name("y")]
)
.filter(lambda t: t.y.notnull())
.group_by("scalar_column")
.aggregate(x=lambda t: t.y.collect())
Expand Down
1 change: 1 addition & 0 deletions ibis/backends/trino/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class TrinoSQLExprTranslator(AlchemyExprTranslator):
ops.Lead,
)
_dialect_name = "trino"
supports_unnest_in_select = False


rewrites = TrinoSQLExprTranslator.rewrites
Expand Down
7 changes: 7 additions & 0 deletions ibis/backends/trino/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,12 @@ def _round(t, op):
return sa.func.round(arg)


def _unnest(t, op):
    """Translate ``ops.Unnest`` for Trino as a derived UNNEST table.

    Produces ``UNNEST(<arg>) AS <name>(<name>)`` and selects that single
    column from the derived table.
    """
    unnested = op.arg
    column_name = unnested.name
    derived = sa.func.unnest(t.translate(unnested)).table_valued(column_name)
    return derived.render_derived().c[column_name]


operation_registry.update(
{
# conditional expressions
Expand Down Expand Up @@ -300,6 +306,7 @@ def _round(t, op):
)
),
ops.TypeOf: unary(sa.func.typeof),
ops.Unnest: _unnest,
}
)

Expand Down
12 changes: 11 additions & 1 deletion ibis/expr/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import functools
import operator
from collections import Counter
from typing import Iterator, Mapping
from typing import Iterable, Iterator, Mapping

import toolz

Expand Down Expand Up @@ -867,3 +867,13 @@ def finder(node):
return g.proceed, node if isinstance(node, ops.InMemoryTable) else None

return g.traverse(finder, node, filter=ops.Node)


def find_toplevel_unnest_children(nodes: Iterable[ops.Node]) -> Iterator[ops.Table]:
    """Yield the first base table of each ``Unnest`` reachable through
    value expressions in *nodes*.

    Traversal descends only into ``ops.Value`` nodes, so unnests buried
    under non-value constructs are not visited.
    """

    def finder(node):
        # keep walking only through value expressions; emit a base table
        # whenever the node itself is an Unnest
        if isinstance(node, ops.Unnest):
            found = find_first_base_table(node)
        else:
            found = None
        return isinstance(node, ops.Value), found

    return g.traverse(finder, nodes, filter=ops.Node)

0 comments on commit 5943ce7

Please sign in to comment.