Skip to content

Commit

Permalink
feat(bigquery): implement array repeat
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Dec 26, 2022
1 parent 2402506 commit 09d1e2f
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 3 deletions.
15 changes: 12 additions & 3 deletions ibis/backends/bigquery/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,17 @@ def translate(t, op: ops.ArgMin | ops.ArgMax) -> str:
return translate


def _array_repeat(t, op):
start = step = 1
times = t.translate(op.times)
arg = t.translate(op.arg)
array_length = f"ARRAY_LENGTH({arg})"
stop = f"GREATEST({times}, 0) * {array_length}"
idx = f"COALESCE(NULLIF(MOD(i, {array_length}), 0), {array_length})"
series = f"GENERATE_ARRAY({start}, {stop}, {step})"
return f"ARRAY(SELECT {arg}[SAFE_ORDINAL({idx})] FROM UNNEST({series}) AS i)"


OPERATION_REGISTRY = {
**operation_registry,
# Literal
Expand Down Expand Up @@ -546,12 +557,10 @@ def translate(t, op: ops.ArgMin | ops.ArgMax) -> str:
ops.ArrayConcat: _array_concat,
ops.ArrayIndex: _array_index,
ops.ArrayLength: unary("ARRAY_LENGTH"),
ops.ArrayRepeat: _array_repeat,
ops.HLLCardinality: reduction("APPROX_COUNT_DISTINCT"),
ops.Log: _log,
ops.Log2: _log2,
# BigQuery doesn't have these operations built in.
# ops.ArrayRepeat: _array_repeat,
# ops.ArraySlice: _array_slice,
ops.Arbitrary: _arbitrary,
# Geospatial Columnar
ops.GeoUnaryUnion: unary("ST_UNION_AGG"),
Expand Down
12 changes: 12 additions & 0 deletions ibis/backends/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,18 @@ def test_array_scalar(con):
assert np.array_equal(result, expected)


@pytest.mark.notimpl(["impala", "snowflake", "polars", "datafusion"])
def test_array_repeat(con):
expr = ibis.array([1.0, 2.0]) * 2

result = con.execute(expr.name("tmp"))
expected = np.array([1.0, 2.0, 1.0, 2.0])

# This does not check whether `result` is an np.array or a list,
# because it varies across backends and backend configurations
assert np.array_equal(result, expected)


# Issues #2370
@pytest.mark.notimpl(["impala", "datafusion", "snowflake"])
def test_array_concat(con):
Expand Down

0 comments on commit 09d1e2f

Please sign in to comment.