Skip to content

Commit

Permalink
refactor(duckdb): use lambda to define backend operations
Browse files Browse the repository at this point in the history
  • Loading branch information
krzysztof-kwitt authored and cpcloud committed Dec 31, 2022
1 parent b937391 commit 5d14de6
Showing 1 changed file with 17 additions and 22 deletions.
39 changes: 17 additions & 22 deletions ibis/backends/duckdb/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,28 +107,13 @@ def _literal(_, op):
return sa.cast(sa.literal(value), sqla_type)


def _array_column(t, op):
    """Translate an array-column op into a typed ``list_value`` call.

    ``op.args`` must contain exactly one item: an iterable of element
    expressions. Each element is translated with ``t.translate`` and passed
    positionally to ``sa.func.list_value``; the resulting expression is cast
    to the SQLAlchemy type derived from ``op.output_dtype``.
    """
    # Single-element unpacking fails loudly if the op carries anything else.
    [columns] = op.args
    target_type = to_sqla_type(op.output_dtype)
    elements = [t.translate(column) for column in columns]
    return sa.cast(sa.func.list_value(*elements), target_type)


def _neg_idx_to_pos(array, idx):
    """Build an expression that maps a negative index onto a position from
    the start of ``array``.

    Non-negative ``idx`` values are passed through unchanged. For a negative
    ``idx`` the result is ``array_length(array) + greatest(idx, -length)``,
    so an index further back than the array is long resolves to 0 rather
    than to a negative position.
    """
    # ``if`` is a Python keyword, so ``sa.func.if`` cannot be written as a
    # plain attribute access.
    sql_if = getattr(sa.func, "if")
    length = sa.func.array_length(array)
    is_negative = idx < 0
    wrapped = length + sa.func.greatest(idx, -length)
    return sql_if(is_negative, wrapped, idx)


def _struct_field(t, op):
    """Translate a struct-field access op into a ``struct_extract`` call.

    The struct expression comes from translating ``op.arg``; the field name
    is embedded as raw SQL text; the result is typed with the SQLAlchemy
    type derived from ``op.output_dtype``.
    """
    struct_expr = t.translate(op.arg)
    # repr() renders the field name in its Python literal form, and sa.text
    # inlines that text into the SQL instead of binding it as a parameter.
    # NOTE(review): presumably op.field is a str, giving a quoted literal.
    field_literal = sa.text(repr(op.field))
    result_type = to_sqla_type(op.output_dtype)
    return sa.func.struct_extract(struct_expr, field_literal, type_=result_type)


def _regex_extract(t, op):
string, pattern, index = map(t.translate, op.args)
def _regex_extract(string, pattern, index):
result = sa.case(
[
(
Expand All @@ -149,8 +134,7 @@ def _regex_extract(t, op):
return result


def _json_get_item(t, op):
left, path = map(t.translate, op.args)
def _json_get_item(left, path):
# Workaround for https://github.com/duckdb/duckdb/issues/5063
# In some situations duckdb silently does the wrong thing if
# the path is parametrized.
Expand Down Expand Up @@ -197,7 +181,12 @@ def _struct_column(t, op):

operation_registry.update(
{
ops.ArrayColumn: _array_column,
ops.ArrayColumn: (
lambda t, op: sa.cast(
sa.func.list_value(*map(t.translate, op.cols)),
to_sqla_type(op.output_dtype),
)
),
ops.ArrayConcat: fixed_arity(sa.func.array_concat, 2),
ops.ArrayRepeat: fixed_arity(
lambda arg, times: sa.func.flatten(
Expand All @@ -222,7 +211,13 @@ def _struct_column(t, op):
# TODO: map operations, but DuckDB's maps are multimaps
ops.Modulus: fixed_arity(operator.mod, 2),
ops.Round: _round,
ops.StructField: _struct_field,
ops.StructField: (
lambda t, op: sa.func.struct_extract(
t.translate(op.arg),
sa.text(repr(op.field)),
type_=to_sqla_type(op.output_dtype),
)
),
ops.TableColumn: _table_column,
ops.TimestampDiff: fixed_arity(sa.func.age, 2),
ops.TimestampFromUNIX: _timestamp_from_unix,
Expand All @@ -232,7 +227,7 @@ def _struct_column(t, op):
lambda *_: sa.cast(sa.func.now(), sa.TIMESTAMP),
0,
),
ops.RegexExtract: _regex_extract,
ops.RegexExtract: fixed_arity(_regex_extract, 3),
ops.RegexReplace: fixed_arity(
lambda *args: sa.func.regexp_replace(*args, "g"), 3
),
Expand All @@ -255,7 +250,7 @@ def _struct_column(t, op):
ops.ArgMin: reduction(sa.func.min_by),
ops.ArgMax: reduction(sa.func.max_by),
ops.BitwiseXor: fixed_arity(sa.func.xor, 2),
ops.JSONGetItem: _json_get_item,
ops.JSONGetItem: fixed_arity(_json_get_item, 2),
ops.RowID: lambda *_: sa.literal_column('rowid'),
ops.StringToTimestamp: fixed_arity(sa.func.strptime, 2),
}
Expand Down

0 comments on commit 5d14de6

Please sign in to comment.