Skip to content

Commit

Permalink
feat(mssql): add hashbytes and test for binary output hash fns (#8107)
Browse files Browse the repository at this point in the history
## Description of changes

This adds support for `ops.HashBytes` to `mssql` and also adds a test
for that functionality so it's easier to port when we merge in the epic
split branch.

I've also added a new op, `HexDigest`, which returns the hex digest of
various cryptographic hashing functions, since I imagine this is what
many users are _probably_ after. This newer op (and the corresponding
`hexdigest` method) can also support many more backends, as most of them
default to returning the string hex digest rather than the raw binary
output.

I tried to be very accurate in the `notimpl` and `notyet` portions of
both tests and I think I've done that.

For now, this is only exposed for DuckDB, PySpark, and MSSQL, so that we
don't add a huge extra burden for the epic split while still addressing
the user request in #8082.

And I guess now we can commence debate over the method name? 🐎

## Issues closed

Resolves #8082

---------

Co-authored-by: Jim Crist-Harif <jcristharif@gmail.com>
  • Loading branch information
gforsyth and jcrist authored Jan 29, 2024
1 parent e087826 commit 91f60cd
Show file tree
Hide file tree
Showing 7 changed files with 189 additions and 0 deletions.
2 changes: 2 additions & 0 deletions ibis/backends/clickhouse/compiler/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,8 @@ def _hash(op, *, arg, **_):

@translate_val.register(ops.HashBytes)
def _hash_bytes(op, *, arg, how, **_):
if how in ("md5", "sha1", "sha224", "sha256"):
how = how.upper()
if how not in _SUPPORTED_ALGORITHMS:
raise com.UnsupportedOperationError(f"Unsupported hash algorithm {how}")

Expand Down
11 changes: 11 additions & 0 deletions ibis/backends/duckdb/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,16 @@ def _array_remove(t, op):
)


def _hexdigest(translator, op):
how = op.how

arg_formatted = translator.translate(op.arg)
if how in ("md5", "sha256"):
return getattr(sa.func, how)(arg_formatted)
else:
raise NotImplementedError(how)


operation_registry.update(
{
ops.Array: (
Expand Down Expand Up @@ -533,6 +543,7 @@ def _array_remove(t, op):
ops.MapValues: unary(sa.func.map_values),
ops.MapMerge: fixed_arity(sa.func.map_concat, 2),
ops.Hash: unary(sa.func.hash),
ops.HexDigest: _hexdigest,
ops.Median: reduction(sa.func.median),
ops.First: reduction(sa.func.first),
ops.Last: reduction(sa.func.last),
Expand Down
43 changes: 43 additions & 0 deletions ibis/backends/mssql/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,47 @@ def _literal(_, op):
return sa.literal(value)


def _hashbytes(translator, op):
how = op.how

arg_formatted = translator.translate(op.arg)

if how in ("md5", "sha1"):
return sa.func.hashbytes(how, arg_formatted)
elif how == "sha256":
return sa.func.hashbytes("sha2_256", arg_formatted)
elif how == "sha512":
return sa.func.hashbytes("sha2_512", arg_formatted)
else:
raise NotImplementedError(how)


def _hexdigest(translator, op):
    """Compile ``ops.HexDigest`` to a lower-cased hex string of the hash.

    Parameters
    ----------
    translator
        Expression translator used to compile ``op.arg``.
    op
        The operation node. ``op.how`` names the hash algorithm and
        ``op.arg`` is the value to hash.

    Returns
    -------
    The SQLAlchemy expression
    ``lower(convert(VARCHAR(MAX), hashbytes(...), 2))``.

    Raises
    ------
    NotImplementedError
        If ``op.how`` is not an algorithm that ``_hashbytes`` supports.
    """
    # SO post on getting convert to play nice with VARCHAR in Sqlalchemy
    # https://stackoverflow.com/questions/20291962/how-to-use-convert-function-in-sqlalchemy

    # Reuse the binary-hash compilation rule instead of repeating its
    # algorithm-name mapping; it only reads ``op.how`` and ``op.arg``.
    hashbinary = _hashbytes(translator, op)

    # mssql uppercases the hexdigest which is inconsistent with several other
    # implementations and inconsistent with Python, so lowercase it.
    return sa.func.lower(
        sa.func.convert(
            sa.literal_column("VARCHAR(MAX)"),
            hashbinary,
            2,  # 2 means strip off leading '0x'
        )
    )


operation_registry = sqlalchemy_operation_registry.copy()
operation_registry.update(sqlalchemy_window_functions_registry)

Expand Down Expand Up @@ -316,6 +357,8 @@ def _literal(_, op):
ops.DateTruncate: _timestamp_truncate,
ops.TimestampBucket: _timestamp_bucket,
ops.Hash: unary(sa.func.checksum),
ops.HashBytes: _hashbytes,
ops.HexDigest: _hexdigest,
ops.ExtractMicrosecond: fixed_arity(
lambda arg: sa.func.datepart(sa.literal_column("microsecond"), arg), 1
),
Expand Down
15 changes: 15 additions & 0 deletions ibis/backends/pyspark/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2065,6 +2065,21 @@ def compile_hash_column(t, op, **kwargs):
return F.hash(t.translate(op.arg, **kwargs))


@compiles(ops.HexDigest)
def compile_hexdigest_column(t, op, **kwargs):
    """Compile ``ops.HexDigest`` to ``F.md5``, ``F.sha1`` or ``F.sha2``.

    The argument is translated first; the algorithm named by ``op.how`` then
    selects the function. Any algorithm other than ``md5``, ``sha1``,
    ``sha256`` or ``sha512`` raises ``NotImplementedError``.
    """
    algorithm = op.how
    column = t.translate(op.arg, **kwargs)

    if algorithm in ("sha256", "sha512"):
        # The trailing three characters of the name (256 / 512) become the
        # integer second argument of F.sha2.
        bits = int(algorithm[-3:])
        return F.sha2(column, bits)
    if algorithm == "md5":
        return F.md5(column)
    if algorithm == "sha1":
        return F.sha1(column)
    raise NotImplementedError(algorithm)


@compiles(ops.ArrayZip)
def compile_zip(t, op, **kwargs):
    """Translate each member of ``op.arg`` and pass them all to ``F.arrays_zip``."""
    translated = [t.translate(member, **kwargs) for member in op.arg]
    return F.arrays_zip(*translated)
Expand Down
70 changes: 70 additions & 0 deletions ibis/backends/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1281,6 +1281,76 @@ def test_hash_consistent(backend, alltypes):
assert h1.dtype in ("i8", "uint64") # polars likes returning uint64 for this


@pytest.mark.notimpl(["trino", "oracle", "exasol", "snowflake"])
@pytest.mark.notyet(
    [
        "dask",
        "datafusion",
        "druid",
        "duckdb",
        "flink",
        "impala",
        "mysql",
        "pandas",
        "polars",
        "postgres",
        "pyspark",
        "sqlite",
    ]
)
def test_hashbytes(backend, alltypes):
    """Compare ``hashbytes()`` on ``string_col`` with ``hashlib.sha256`` digests.

    The first ten rows (ordered by ``id``) are hashed by the backend and,
    independently, in Python; the two series must be equal.
    """
    import hashlib

    ordered = alltypes.order_by("id")
    result = ordered.string_col.hashbytes().execute(limit=10)
    frame = ordered.execute(limit=10)

    # Reference values: raw (binary) SHA-256 digest of each encoded string.
    expected = (
        frame["string_col"]
        .apply(lambda value: hashlib.sha256(value.encode()).digest())
        .rename("HashBytes(string_col)")
    )

    backend.assert_series_equal(result, expected)


@pytest.mark.notimpl(
    [
        "bigquery",
        "clickhouse",
        "dask",
        "datafusion",
        "exasol",
        "flink",
        "impala",
        "mysql",
        "oracle",
        "pandas",
        "polars",
        "postgres",
        "snowflake",
        "trino",
    ]
)
# "polars" is already covered by the notimpl marker above, so it is not
# repeated here: a backend should carry only one of the two classifications.
@pytest.mark.notyet(
    [
        "druid",
        "sqlite",
    ]
)
def test_hexdigest(backend, alltypes):
    """Compare ``hexdigest()`` on ``string_col`` with ``hashlib.sha256`` hex digests.

    The first ten rows (ordered by ``id``) are hashed by the backend and,
    independently, in Python; the two series must be equal.
    """
    import hashlib

    h1 = alltypes.order_by("id").string_col.hexdigest().execute(limit=10)
    df = alltypes.order_by("id").execute(limit=10)

    def hash_256(col):
        # Reference value: lower-case hex SHA-256 digest of the encoded string.
        return hashlib.sha256(col.encode()).hexdigest()

    h2 = df["string_col"].apply(hash_256).rename("HexDigest(string_col)")

    backend.assert_series_equal(h1, h2)


@pytest.mark.notimpl(
[
"pandas",
Expand Down
14 changes: 14 additions & 0 deletions ibis/expr/operations/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,20 @@ class HashBytes(Value):
shape = rlz.shape_like("arg")


@public
class HexDigest(Value):
    """Hash the input with the chosen algorithm, producing a string result."""

    # Value to hash: a string or binary expression.
    arg: Value[dt.String | dt.Binary]
    # Name of the hash algorithm to apply.
    how: LiteralType["md5", "sha1", "sha256", "sha512"]

    # The result is a string with the same shape as ``arg``.
    dtype = dt.str
    shape = rlz.shape_like("arg")


# TODO(kszucs): we should merge the case operations by making the
# cases, results and default optional arguments like they are in
# api.py
Expand Down
34 changes: 34 additions & 0 deletions ibis/expr/types/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,40 @@ def hashbytes(
"""
return ops.HashBytes(self, how).to_expr()

def hexdigest(
    self,
    how: Literal["md5", "sha1", "sha256", "sha512"] = "sha256",
) -> ir.StringValue:
    """Hash the input and return the digest as a hex-encoded string.

    Parameters
    ----------
    how
        Name of the hash algorithm to use. It is lower-cased before being
        passed to the underlying operation.

    Returns
    -------
    StringValue
        Hexadecimal representation of the hash as a string

    Examples
    --------
    >>> import ibis
    >>> ibis.options.interactive = True
    >>> t = ibis.memtable({"species": ["Adelie", "Chinstrap", "Gentoo"]})
    >>> t.species.hexdigest()
    ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
    ┃ HexDigest(species)                                               ┃
    ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
    │ string                                                           │
    ├──────────────────────────────────────────────────────────────────┤
    │ a4d7d46b27480037bc1e513e0e157cbf258baae6ee69e3110d0f9ff418b57a3c │
    │ cb97d113ca69899ae4f1fb581f4a90d86989db77b4a33873d604b0ee412b4cc9 │
    │ b5e90cdff65949fe6bc226823245f7698110e563a12363fc57b3eed3e4a0a612 │
    └──────────────────────────────────────────────────────────────────┘
    """
    algorithm = how.lower()
    node = ops.HexDigest(self, algorithm)
    return node.to_expr()

def substr(
self,
start: int | ir.IntegerValue,
Expand Down

0 comments on commit 91f60cd

Please sign in to comment.