Skip to content

Commit

Permalink
feat(postgres): implement array functions
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Aug 5, 2023
1 parent a93cbd6 commit fe41d57
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 41 deletions.
64 changes: 64 additions & 0 deletions ibis/backends/postgres/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,11 +560,56 @@ def _array_sort(arg):
return sa.func.array(sa.select(flat).order_by(flat).scalar_subquery())


def _array_position(haystack, needle):
t = (
sa.func.unnest(haystack)
.table_valued("value", with_ordinality="idx", name="haystack")
.render_derived()
)
idx = t.c.idx - 1
return sa.func.coalesce(
sa.select(idx).where(t.c.value == needle).limit(1).scalar_subquery(), -1
)


def _array_map(t, op):
return sa.func.array(
# this translates to the function call, with column names the same as
# the parameter names in the lambda
sa.select(t.translate(op.result))
.select_from(
# unnest the input array
sa.func.unnest(t.translate(op.arg))
# name the columns of the result the same as the lambda parameter
# so that we can reference them as such in the outer query
.table_valued(op.parameter).render_derived()
)
.scalar_subquery()
)


def _array_filter(t, op):
param = op.parameter
return sa.func.array(
sa.select(
sa.column(param, type_=t.get_sqla_type(op.arg.output_dtype.value_type))
)
.select_from(
sa.func.unnest(t.translate(op.arg)).table_valued(param).render_derived()
)
.where(t.translate(op.result))
.scalar_subquery()
)


operation_registry.update(
{
ops.Literal: _literal,
# We override this here to support time zones
ops.TableColumn: _table_column,
ops.Argument: lambda t, op: sa.column(
op.name, type_=t.get_sqla_type(op.output_dtype)
),
# types
ops.TypeOf: _typeof,
# Floating
Expand Down Expand Up @@ -731,5 +776,24 @@ def _array_sort(arg):
),
2,
),
ops.ArrayRemove: fixed_arity(
lambda left, right: sa.func.array(
sa.except_(
sa.select(sa.func.unnest(left).column_valued()), sa.select(right)
).scalar_subquery()
),
2,
),
ops.ArrayDistinct: fixed_arity(
lambda arg: sa.func.array(
sa.select(
sa.distinct(sa.func.unnest(arg).column_valued())
).scalar_subquery()
),
1,
),
ops.ArrayPosition: fixed_arity(_array_position, 2),
ops.ArrayMap: _array_map,
ops.ArrayFilter: _array_filter,
}
)
45 changes: 4 additions & 41 deletions ibis/backends/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,16 +485,7 @@ def test_array_slice(backend, start, stop):


@pytest.mark.notimpl(
[
"bigquery",
"datafusion",
"impala",
"mssql",
"polars",
"postgres",
"snowflake",
"sqlite",
],
["bigquery", "datafusion", "impala", "mssql", "polars", "snowflake", "sqlite"],
raises=com.OperationNotDefinedError,
)
@pytest.mark.notimpl(
Expand Down Expand Up @@ -526,7 +517,6 @@ def test_array_map(backend, con):
"mssql",
"pandas",
"polars",
"postgres",
"snowflake",
],
raises=com.OperationNotDefinedError,
Expand Down Expand Up @@ -567,16 +557,7 @@ def test_array_contains(backend, con):


@pytest.mark.notimpl(
[
"bigquery",
"dask",
"datafusion",
"impala",
"mssql",
"pandas",
"polars",
"postgres",
],
["bigquery", "dask", "datafusion", "impala", "mssql", "pandas", "polars"],
raises=com.OperationNotDefinedError,
)
@pytest.mark.notimpl(
Expand All @@ -593,16 +574,7 @@ def test_array_position(backend, con):


@pytest.mark.notimpl(
[
"bigquery",
"dask",
"datafusion",
"impala",
"mssql",
"pandas",
"polars",
"postgres",
],
["bigquery", "dask", "datafusion", "impala", "mssql", "pandas", "polars"],
raises=com.OperationNotDefinedError,
)
@pytest.mark.notimpl(
Expand All @@ -619,16 +591,7 @@ def test_array_remove(backend, con):


@pytest.mark.notimpl(
[
"bigquery",
"dask",
"datafusion",
"impala",
"mssql",
"pandas",
"polars",
"postgres",
],
["bigquery", "dask", "datafusion", "impala", "mssql", "pandas", "polars"],
raises=com.OperationNotDefinedError,
)
@pytest.mark.notimpl(
Expand Down

0 comments on commit fe41d57

Please sign in to comment.