
Commit

feat(bigquery): use sqlglot to implement functional unnest to relational unnest
cpcloud committed Sep 24, 2023
1 parent be861d9 commit 167c3bd
Showing 13 changed files with 156 additions and 71 deletions.
51 changes: 49 additions & 2 deletions ibis/backends/bigquery/__init__.py
@@ -12,6 +12,7 @@
import google.cloud.bigquery_storage_v1 as bqstorage
import pandas as pd
import pydata_google_auth
import sqlglot as sg
from pydata_google_auth import cache

import ibis
@@ -74,6 +75,17 @@ def _create_client_info_gapic(application_name):
    return ClientInfo(user_agent=_create_user_agent(application_name))


def _anonymous_unnest_to_explode(node: sg.exp.Expression):
    """Convert `ANONYMOUS` `unnest` function calls to `EXPLODE` calls.

    This allows us to generate DuckDB-like `UNNEST` calls and let sqlglot do
    the work of transforming those into the correct BigQuery SQL.
    """
    if isinstance(node, sg.exp.Anonymous) and node.this.lower() == "unnest":
        return sg.exp.Explode(this=node.expressions[0])
    return node
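
As an aside, the rewrite is easy to exercise in isolation. A minimal sketch (not part of the diff; the node is hand-built rather than parsed, and `xs` is a placeholder column name):

import sqlglot as sg

# Hand-build the kind of anonymous `unnest(...)` call the ibis string
# compiler emits, then apply the rewrite; the result is an Explode node,
# which sqlglot's BigQuery dialect knows how to lower to a relational UNNEST.
node = sg.exp.Anonymous(this="unnest", expressions=[sg.exp.column("xs")])
print(_anonymous_unnest_to_explode(node).sql())  # EXPLODE(xs)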


class Backend(BaseSQLBackend, CanCreateSchema, CanListDatabases):
    name = "bigquery"
    compiler = BigQueryCompiler
@@ -313,6 +325,42 @@ def _execute(self, stmt, results=True, query_parameters=None):
        query.result()  # blocks until finished
        return BigQueryCursor(query)

    def compile(
        self,
        expr: ir.Expr,
        limit: str | None = None,
        params: Mapping[ir.Expr, Any] | None = None,
        **_,
    ) -> Any:
        """Compile an Ibis expression.

        Parameters
        ----------
        expr
            Ibis expression
        limit
            For expressions yielding result sets; retrieve at most this number
            of values/rows. Overrides any limit already set on the expression.
        params
            Named unbound parameters

        Returns
        -------
        Any
            The output of compilation. The type of this value depends on the
            backend.
        """
        self._define_udf_translation_rules(expr)
        sql = self.compiler.to_ast_ensure_limit(expr, limit, params=params).compile()

        return ";\n\n".join(
            query.transform(_anonymous_unnest_to_explode).sql(
                dialect="bigquery", pretty=True
            )
            for query in sg.parse(sql, read="bigquery")
        )
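
The design here: the legacy string compiler keeps emitting a DuckDB-style `unnest(...)` call in SELECT position, and sqlglot does the heavy lifting of lowering that into BigQuery's relational UNNEST (see the snapshot files below). A sketch of the round trip, assuming the installed sqlglot version parses the bare `unnest(...)` call as `exp.Anonymous`, as the helper above expects:

import sqlglot as sg

# Parse ibis-style output, rewrite unnest -> EXPLODE, then let the
# BigQuery generator lower EXPLODE into UNNEST ... WITH OFFSET.
sql = "SELECT unnest(`repeated_struct_col`) AS x FROM array_test"
(tree,) = sg.parse(sql, read="bigquery")
print(tree.transform(_anonymous_unnest_to_explode).sql(dialect="bigquery", pretty=True))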

    def raw_sql(self, query: str, results=False, params=None):
        query_parameters = [
            bigquery_param(
@@ -378,8 +426,7 @@ def execute(self, expr, params=None, limit="default", **kwargs):

        # TODO: upstream needs to pass params to raw_sql, I think.
        kwargs.pop("timecontext", None)
        query_ast = self.compiler.to_ast_ensure_limit(expr, limit, params=params)
        sql = query_ast.compile()
        sql = self.compile(expr, limit=limit, params=params, **kwargs)
        self._log(sql)
        cursor = self.raw_sql(sql, params=params, **kwargs)

1 change: 1 addition & 0 deletions ibis/backends/bigquery/registry.py
@@ -903,6 +903,7 @@ def _count_distinct_star(t, op):
    ops.TableColumn: table_column,
    ops.CountDistinctStar: _count_distinct_star,
    ops.Argument: lambda _, op: op.name,
    ops.Unnest: unary("UNNEST"),
}

_invalid_operations = {
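
The one-line registry change above is what lets the compiler emit `unnest` at all. For context, a rough sketch of what a `unary` registry helper does (assumed from its usage here; the helper itself is not shown in this diff):

# Hypothetical sketch of the `unary` helper: translate the operation's
# single argument and wrap it in a one-argument function call, so that
# ops.Unnest renders as the string UNNEST(<compiled argument>).
def unary(func_name):
    def formatter(translator, op):
        (arg,) = op.args
        return f"{func_name}({translator.translate(arg)})"

    return formatter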
15 changes: 15 additions & 0 deletions ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_unnest/out_one_unnest.sql
@@ -0,0 +1,15 @@
SELECT
  t0.`rowindex`,
  IF(pos = pos_2, repeated_struct_col, NULL) AS repeated_struct_col
FROM array_test AS t0, UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH(t0.`repeated_struct_col`)) - 1)) AS pos
CROSS JOIN UNNEST(t0.`repeated_struct_col`) AS repeated_struct_col WITH OFFSET AS pos_2
WHERE
  pos = pos_2
  OR (
    pos > (
      ARRAY_LENGTH(t0.`repeated_struct_col`) - 1
    )
    AND pos_2 = (
      ARRAY_LENGTH(t0.`repeated_struct_col`) - 1
    )
  )
30 changes: 30 additions & 0 deletions ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_unnest/out_two_unnests.sql
@@ -0,0 +1,30 @@
SELECT
  IF(pos = pos_2, level_two, NULL) AS level_two
FROM (
  SELECT
    t1.`rowindex`,
    IF(pos = pos_2, level_one, NULL).`nested_struct_col` AS level_one
  FROM array_test AS t1, UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH(t1.`repeated_struct_col`)) - 1)) AS pos
  CROSS JOIN UNNEST(t1.`repeated_struct_col`) AS level_one WITH OFFSET AS pos_2
  WHERE
    pos = pos_2
    OR (
      pos > (
        ARRAY_LENGTH(t1.`repeated_struct_col`) - 1
      )
      AND pos_2 = (
        ARRAY_LENGTH(t1.`repeated_struct_col`) - 1
      )
    )
) AS t0, UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH(t0.`level_one`)) - 1)) AS pos
CROSS JOIN UNNEST(t0.`level_one`) AS level_two WITH OFFSET AS pos_2
WHERE
  pos = pos_2
  OR (
    pos > (
      ARRAY_LENGTH(t0.`level_one`) - 1
    )
    AND pos_2 = (
      ARRAY_LENGTH(t0.`level_one`) - 1
    )
  )
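
Reading the generated SQL above: each unnest becomes a cross join between the exploded array (UNNEST ... WITH OFFSET) and a generated index array over the longest relevant array length, and the IF(pos = pos_2, ...) projection together with the WHERE clause keeps one row per array position; the OR branch appears to exist to pad out arrays shorter than the generated range when several arrays are unnested in one SELECT. This positional-join pattern is sqlglot's lowering of EXPLODE for dialects without SELECT-position unnest (its explode_to_unnest transform), which is why converting the anonymous unnest call into an Explode node is all the backend needs to do.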
47 changes: 39 additions & 8 deletions ibis/backends/bigquery/tests/unit/test_compiler.py
@@ -399,10 +399,7 @@ def test_timestamp_accepts_date_literals(alltypes):
    expr = alltypes.mutate(param=p)
    params = {p: date_string}
    result = to_sql(expr, params=params)
    expected = """\
SELECT t\\d+\\.\\*, @param_\\d+ AS `param`
FROM functional_alltypes t\\d+"""
    assert re.match(expected, result) is not None
    assert re.search(r"@param_\d+ AS `param`", result) is not None


@pytest.mark.parametrize("distinct", [True, False])
@@ -587,13 +584,47 @@ def test_scalar_param_scope(alltypes):
    t = alltypes
    param = ibis.param("timestamp")
    result = to_sql(t.mutate(param=param), params={param: "2017-01-01"})
    expected = """\
SELECT t\\d+\\.\\*, @param_\\d+ AS `param`
FROM functional_alltypes t\\d+"""
    assert re.match(expected, result) is not None
    assert re.search(r"@param_\d+ AS `param`", result) is not None


def test_cast_float_to_int(alltypes, snapshot):
    expr = alltypes.double_col.cast("int64")
    result = to_sql(expr)
    snapshot.assert_match(result, "out.sql")


def test_unnest(snapshot):
    table = ibis.table(
        dict(
            rowindex="int",
            repeated_struct_col=dt.Array(
                dt.Struct(
                    dict(
                        nested_struct_col=dt.Array(
                            dt.Struct(
                                dict(
                                    doubly_nested_array="array<int>",
                                    doubly_nested_field="string",
                                )
                            )
                        )
                    )
                )
            ),
        ),
        name="array_test",
    )
    repeated_struct_col = table.repeated_struct_col

    # Works as expected :-)
    result = ibis.bigquery.compile(
        table.select("rowindex", repeated_struct_col.unnest())
    )
    snapshot.assert_match(result, "out_one_unnest.sql")

    result = ibis.bigquery.compile(
        table.select(
            "rowindex", level_one=repeated_struct_col.unnest().nested_struct_col
        ).select(level_two=lambda t: t.level_one.unnest())
    )
    snapshot.assert_match(result, "out_two_unnests.sql")
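
For reference, a stripped-down usage sketch of the behavior this test pins down (hypothetical table and column names; compile only, no execution):

import ibis

t = ibis.table({"xs": "array<int64>"}, name="t")
# unnest in SELECT position now compiles to BigQuery's relational form
print(ibis.bigquery.compile(t.select(x=t.xs.unnest())))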
5 changes: 0 additions & 5 deletions ibis/backends/tests/test_aggregation.py
@@ -1261,11 +1261,6 @@ def test_topk_op(alltypes, df):
        )
    ],
)
@mark.broken(
    ["bigquery"],
    raises=GoogleBadRequest,
    reason='400 Syntax error: Expected keyword JOIN but got identifier "SEMI"',
)
@mark.broken(
    ["druid"],
    raises=sa.exc.ProgrammingError,
44 changes: 11 additions & 33 deletions ibis/backends/tests/test_array.py
@@ -2,7 +2,6 @@

import contextlib
import functools
import os

import numpy as np
import pandas as pd
@@ -11,7 +10,6 @@
import sqlalchemy as sa
import sqlglot as sg
import toolz
from packaging.version import parse as parse_version
from pytest import param

import ibis
@@ -185,20 +183,6 @@ def test_array_index(con, idx):
    assert result == arr[idx]


duckdb_0_4_0 = pytest.mark.xfail(
    (
        # nixpkgs is patched to include the fix, so we pass these tests
        # inside the nix-shell or when they run under `nix build`
        (not any(key.startswith("NIX_") for key in os.environ))
        and (
            parse_version(getattr(duckdb, "__version__", "0.0.0"))
            == parse_version("0.4.0")
        )
    ),
    reason="DuckDB array support is broken in 0.4.0 without nix",
)


builtin_array = toolz.compose(
    # these will almost certainly never be supported
    pytest.mark.never(
@@ -211,16 +195,6 @@ def test_array_index(con, idx):
    ),
    # someone just needs to implement these
    pytest.mark.notimpl(["datafusion"], raises=Exception),
    duckdb_0_4_0,
)


unnest = toolz.compose(
    builtin_array,
    pytest.mark.notyet(
        ["bigquery"],
        reason="doesn't support unnest in SELECT position",
        raises=com.OperationNotDefinedError,
    ),
)


@@ -354,7 +328,12 @@ def test_array_discovery_snowflake(backend):
    assert t.schema() == expected


@unnest
@builtin_array
@pytest.mark.notyet(
    ["bigquery"],
    reason="BigQuery doesn't support casting array<T> to array<U>",
    raises=BadRequest,
)
@pytest.mark.notimpl(["dask"], raises=ValueError)
def test_unnest_simple(backend):
    array_types = backend.array_types
@@ -370,7 +349,7 @@ def test_unnest_simple(backend):
    tm.assert_series_equal(result, expected)


@unnest
@builtin_array
@pytest.mark.notimpl("dask", raises=com.OperationNotDefinedError)
def test_unnest_complex(backend):
    array_types = backend.array_types
@@ -396,7 +375,7 @@ def test_unnest_complex(backend):
    tm.assert_frame_equal(result, expected)


@unnest
@builtin_array
@pytest.mark.never(
    "pyspark",
    reason="pyspark throws away nulls in collect_list",
@@ -426,7 +405,7 @@ def test_unnest_idempotent(backend):
    tm.assert_frame_equal(result, expected)


@unnest
@builtin_array
@pytest.mark.notimpl("dask", raises=ValueError)
def test_unnest_no_nulls(backend):
    array_types = backend.array_types
@@ -452,7 +431,7 @@ def test_unnest_no_nulls(backend):
    tm.assert_frame_equal(result, expected)


@unnest
@builtin_array
@pytest.mark.notimpl("dask", raises=ValueError)
def test_unnest_default_name(backend):
    array_types = backend.array_types
@@ -720,7 +699,6 @@ def test_array_intersect(con):
    assert lhs == rhs, f"row {i:d} differs"


@unnest
@builtin_array
@pytest.mark.notimpl(
    ["clickhouse"],
@@ -767,7 +745,7 @@ def test_zip(backend):
    assert len(x[0]) == len(s[0])


@unnest
@builtin_array
@pytest.mark.broken(
    ["clickhouse"],
    raises=sg.ParseError,
5 changes: 0 additions & 5 deletions ibis/backends/tests/test_generic.py
@@ -972,11 +972,6 @@ def query(t, group_cols):

@pytest.mark.notimpl(["dask", "pandas", "oracle"], raises=com.OperationNotDefinedError)
@pytest.mark.notimpl(["druid"], raises=AssertionError)
@pytest.mark.notyet(
    ["bigquery"],
    reason="backend doesn't implement unnest",
    raises=com.OperationNotDefinedError,
)
@pytest.mark.notyet(
    ["datafusion", "impala", "mssql", "mysql", "sqlite"],
    reason="backend doesn't support arrays and we don't implement pivot_longer with unions yet",
4 changes: 2 additions & 2 deletions ibis/backends/tests/test_join.py
@@ -122,7 +122,7 @@ def test_mutating_join(backend, batting, awards_players, how):


@pytest.mark.parametrize("how", ["semi", "anti"])
@pytest.mark.notimpl(["bigquery", "dask", "druid"])
@pytest.mark.notimpl(["dask", "druid"])
def test_filtering_join(backend, batting, awards_players, how):
    left = batting[batting.yearID == 2015]
    right = awards_players[awards_players.lgID == "NL"].drop("yearID", "lgID")
@@ -181,7 +181,7 @@ def test_mutate_then_join_no_column_overlap(batting, awards_players):
    assert not expr.limit(5).execute().empty


@pytest.mark.notimpl(["bigquery", "druid"])
@pytest.mark.notimpl(["druid"])
@pytest.mark.notyet(["dask"], reason="dask doesn't support descending order by")
@pytest.mark.broken(
    ["polars"],
9 changes: 1 addition & 8 deletions ibis/backends/tests/test_sql.py
@@ -149,14 +149,7 @@ def test_isin_bug(con, snapshot):
["sqlite", "mysql", "druid", "impala", "mssql"], reason="no unnest support upstream"
)
@pytest.mark.notimpl(
["bigquery", "oracle"],
reason="unnest not yet implemented",
raises=exc.OperationNotDefinedError,
)
@pytest.mark.xfail_version(
duckdb=["sqlglot<=11.4.5"],
raises=sg.ParseError,
reason="https://github.com/tobymao/sqlglot/pull/1379 not in the installed version of sqlglot",
["oracle"], reason="unnest not yet implemented", raises=exc.OperationNotDefinedError
)
@pytest.mark.parametrize("backend_name", _get_backends_to_test())
def test_union_aliasing(backend_name, snapshot):
