Skip to content

Commit

Permalink
[FEAT]: hash expr (#2398)
Browse files Browse the repository at this point in the history
closes #1905
  • Loading branch information
universalmind303 authored Jun 20, 2024
1 parent 3aeba6f commit a2bb7f8
Show file tree
Hide file tree
Showing 7 changed files with 158 additions and 0 deletions.
1 change: 1 addition & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1032,6 +1032,7 @@ class PyExpr:
def mean(self) -> PyExpr: ...
def min(self) -> PyExpr: ...
def max(self) -> PyExpr: ...
# Builds a hash expression over this expression. The optional `seed` is itself an
# expression: the Rust binding declares it as `Option<PyExpr>`.
def hash(self, seed: PyExpr | None = None) -> PyExpr: ...
def any_value(self, ignore_nulls: bool) -> PyExpr: ...
def agg_list(self) -> PyExpr: ...
def agg_concat(self) -> PyExpr: ...
Expand Down
10 changes: 10 additions & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,16 @@ def between(self, lower: Any, upper: Any) -> Expression:
expr = self._expr.between(lower._expr, upper._expr)
return Expression._from_pyexpr(expr)

def hash(self, seed: Any | None = None) -> Expression:
    """Hashes the values in the Expression.

    Args:
        seed: Optional seed for the hash. May be an ``Expression`` or any value
            accepted by ``lit``; a non-Expression value is wrapped with ``lit``
            before being handed to the underlying expression. When ``None``,
            the hash is built without a seed argument.

    Returns:
        Expression: a new Expression wrapping the underlying hash call.
    """
    # Unseeded: call the binding with no argument at all rather than passing None through.
    if seed is None:
        return Expression._from_pyexpr(self._expr.hash())

    seed_expr = seed if isinstance(seed, Expression) else lit(seed)
    return Expression._from_pyexpr(self._expr.hash(seed_expr._expr))

def name(self) -> builtins.str:
    """Returns the name reported by the underlying expression object."""
    return self._expr.name()

Expand Down
76 changes: 76 additions & 0 deletions src/daft-dsl/src/functions/hash.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
use common_error::{DaftError, DaftResult};
use daft_core::{
datatypes::{Field, UInt64Array},
schema::Schema,
DataType, IntoSeries, Series,
};

use crate::{
functions::{FunctionEvaluator, FunctionExpr},
Expr, ExprRef,
};

/// Stateless evaluator for the `hash` function expression.
///
/// Carries no data; all behaviour lives in its `FunctionEvaluator` implementation.
pub(super) struct HashEvaluator {}

/// `hash(input)` / `hash(input, seed)`: hashes every value of `input`, optionally seeded.
impl FunctionEvaluator for HashEvaluator {
    /// Name under which this function is reported.
    fn fn_name(&self) -> &'static str {
        "hash"
    }

    /// Resolves the output field: the first input's name with type `UInt64`.
    ///
    /// Accepts `[input]` or `[input, seed]`; the seed does not influence the field.
    ///
    /// # Errors
    /// Propagates any error from resolving `input` against `schema`, and returns
    /// `DaftError::SchemaMismatch` for any other number of inputs.
    fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult<Field> {
        match inputs {
            [input] | [input, _] => input
                .to_field(schema)
                .map(|field| Field::new(field.name, DataType::UInt64)),
            _ => Err(DaftError::SchemaMismatch(format!(
                "Expected 1 or 2 input args, got {}",
                inputs.len()
            ))),
        }
    }

    /// Hashes `input`, optionally using `seed`.
    ///
    /// The seed is cast to `UInt64` and must either hold exactly one value, which is
    /// repeated for every row of `input`, or have the same length as `input`.
    ///
    /// # Errors
    /// Returns `DaftError::ValueError` when the seed length is neither 1 nor the input
    /// length, or when the number of inputs is not 1 or 2. Errors from the cast and
    /// from the underlying hash are propagated.
    fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult<Series> {
        match inputs {
            [input] => input.hash(None).map(|s| s.into_series()),
            [input, seed] => {
                // Check the shape first so a bad length is reported as such, not as a cast error.
                let broadcast = seed.len() == 1;
                if !broadcast && seed.len() != input.len() {
                    return Err(DaftError::ValueError(
                        "Seed must be a single value or the same length as the input".to_string(),
                    ));
                }

                let seed = seed.cast(&DataType::UInt64)?;
                let seed = seed
                    .u64()
                    .expect("a series cast to UInt64 downcasts to a UInt64 array");

                if broadcast {
                    // There's no way to natively extend the array, so we extract the element
                    // and repeat it. The element is kept as an `Option` so that a null seed is
                    // broadcast as null (as a same-length seed column would supply it) instead
                    // of panicking on `unwrap`.
                    let scalar = seed.get(0);
                    let seed = UInt64Array::from_iter(
                        "seed",
                        std::iter::repeat(scalar).take(input.len()),
                    );
                    input.hash(Some(&seed)).map(|s| s.into_series())
                } else {
                    input.hash(Some(seed)).map(|s| s.into_series())
                }
            }
            _ => Err(DaftError::ValueError(format!(
                "Expected 1 or 2 input args, got {}",
                inputs.len()
            ))),
        }
    }
}

/// Builds a `hash` function expression over `input`.
///
/// The resulting expression has `input` as its first argument and, when `seed` is
/// `Some`, the seed expression as its second argument.
pub fn hash(input: ExprRef, seed: Option<ExprRef>) -> ExprRef {
    // `Option` is iterable, so extending appends the seed only when one was given.
    let mut inputs = vec![input];
    inputs.extend(seed);

    let expr = Expr::Function {
        func: FunctionExpr::Hash,
        inputs,
    };
    expr.into()
}
4 changes: 4 additions & 0 deletions src/daft-dsl/src/functions/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod float;
pub mod hash;
pub mod image;
pub mod json;
pub mod list;
Expand Down Expand Up @@ -29,6 +30,7 @@ use self::{float::FloatExpr, uri::UriExpr};
use common_error::DaftResult;
use daft_core::datatypes::FieldID;
use daft_core::{datatypes::Field, schema::Schema, series::Series};
use hash::HashEvaluator;
use serde::{Deserialize, Serialize};

#[cfg(feature = "python")]
Expand All @@ -52,6 +54,7 @@ pub enum FunctionExpr {
Python(PythonUDF),
Partitioning(PartitioningExpr),
Uri(UriExpr),
Hash,
}

pub trait FunctionEvaluator {
Expand Down Expand Up @@ -84,6 +87,7 @@ impl FunctionExpr {
#[cfg(feature = "python")]
Python(expr) => expr,
Partitioning(expr) => expr.get_evaluator(),
Hash => &HashEvaluator {},
}
}
}
Expand Down
5 changes: 5 additions & 0 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -813,6 +813,11 @@ impl PyExpr {
)
.into())
}

/// Wraps this expression in a `hash` function expression, forwarding the optional
/// `seed` expression to `crate::functions::hash::hash`.
pub fn hash(&self, seed: Option<PyExpr>) -> PyResult<Self> {
    let seed_expr = seed.map(|expr| expr.into());
    let hashed = crate::functions::hash::hash(self.into(), seed_expr);
    Ok(hashed.into())
}
}

impl_bincode_py_state_serialization!(PyExpr);
Expand Down
18 changes: 18 additions & 0 deletions tests/expressions/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,24 @@ def test_repr_functions_sqrt() -> None:
assert repr_out == repr(copied)


def test_repr_functions_hash() -> None:
    """The repr of an unseeded hash expression is as expected and survives a deep copy."""
    hashed = col("a").hash()
    rendered = repr(hashed)
    assert rendered == "hash(col(a))"
    # A deep copy must render identically to the original expression.
    assert rendered == repr(copy.deepcopy(hashed))


def test_repr_functions_hash_2() -> None:
    """The repr of a seeded hash expression lists the seed and survives a deep copy."""
    hashed = col("a").hash(lit(1))
    rendered = repr(hashed)
    assert rendered == "hash(col(a), lit(1))"
    # A deep copy must render identically to the original expression.
    assert rendered == repr(copy.deepcopy(hashed))


def test_expr_structurally_equal() -> None:
e1 = (col("a").max() == col("b").alias("moo") - 3).is_null()
e2 = (col("a").max() == col("b").alias("moo") - 3).is_null()
Expand Down
44 changes: 44 additions & 0 deletions tests/table/test_hash.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import daft
from daft import col


def test_table_expr_hash():
    """Unseeded hashing of a string column and a nullable int column gives the pinned values."""
    data = {
        "utf8": ["foo", "bar"],
        "int": [1, None],
    }
    hashed = daft.from_pydict(data).select(col("utf8").hash(), col("int").hash())
    assert hashed.to_pydict() == {
        "utf8": [12352915711150947722, 15304296276065178466],
        "int": [3439722301264460078, 3244421341483603138],
    }


def test_table_expr_hash_with_seed():
    """Hashing with the scalar seed 42 gives the pinned values for both columns."""
    data = {
        "utf8": ["foo", "bar"],
        "int": [1, None],
    }
    seed_value = 42
    hashed = daft.from_pydict(data).select(
        col("utf8").hash(seed=seed_value),
        col("int").hash(seed=seed_value),
    )
    assert hashed.to_pydict() == {
        "utf8": [15221504070560512414, 2671805001252040144],
        "int": [16405722695416140795, 3244421341483603138],
    }


def test_table_expr_hash_with_seed_array():
    """Hashing with a per-row seed column gives the pinned values."""
    data = {
        "utf8": ["foo", "bar"],
        "seed": [1, 1000],
    }
    # The seed is supplied as a column expression, one seed value per input row.
    hashed = daft.from_pydict(data).select(col("utf8").hash(seed=col("seed")))
    assert hashed.to_pydict() == {"utf8": [6076897603942036120, 15438169081903732554]}

0 comments on commit a2bb7f8

Please sign in to comment.