Skip to content

Commit

Permalink
feat(pyspark, duckdb): add hexdigest support
Browse files Browse the repository at this point in the history
  • Loading branch information
gforsyth committed Jan 26, 2024
1 parent c16977d commit 531d250
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
11 changes: 11 additions & 0 deletions ibis/backends/duckdb/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,16 @@ def _array_remove(t, op):
)


def _hexdigest(translator, op):
how = op.how

arg_formatted = translator.translate(op.arg)
if how in ("md5", "sha256"):
return getattr(sa.func, how)(arg_formatted)
else:
raise NotImplementedError(how)


operation_registry.update(
{
ops.Array: (
Expand Down Expand Up @@ -533,6 +543,7 @@ def _array_remove(t, op):
ops.MapValues: unary(sa.func.map_values),
ops.MapMerge: fixed_arity(sa.func.map_concat, 2),
ops.Hash: unary(sa.func.hash),
ops.HexDigest: _hexdigest,
ops.Median: reduction(sa.func.median),
ops.First: reduction(sa.func.first),
ops.Last: reduction(sa.func.last),
Expand Down
15 changes: 15 additions & 0 deletions ibis/backends/pyspark/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2065,6 +2065,21 @@ def compile_hash_column(t, op, **kwargs):
return F.hash(t.translate(op.arg, **kwargs))


@compiles(ops.HexDigest)
def compile_hexdigest_column(t, op, **kwargs):
how = op.how
arg = t.translate(op.arg, **kwargs)

if how == "md5":
return F.md5(arg)

Check warning on line 2074 in ibis/backends/pyspark/compiler.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/pyspark/compiler.py#L2074

Added line #L2074 was not covered by tests
elif how == "sha1":
return F.sha1(arg)

Check warning on line 2076 in ibis/backends/pyspark/compiler.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/pyspark/compiler.py#L2076

Added line #L2076 was not covered by tests
elif how in ("sha256", "sha512"):
return F.sha2(arg, int(how[-3:]))
else:
raise NotImplementedError(how)


@compiles(ops.ArrayZip)
def compile_zip(t, op, **kwargs):
return F.arrays_zip(*map(partial(t.translate, **kwargs), op.arg))
Expand Down

0 comments on commit 531d250

Please sign in to comment.