Skip to content

Commit

Permalink
feat: Prototype calling Python UDFs from rust (#559)
Browse files Browse the repository at this point in the history
  • Loading branch information
bjchambers authored Jul 26, 2023
1 parent 01b0d6c commit 44ef3f9
Show file tree
Hide file tree
Showing 10 changed files with 1,057 additions and 5 deletions.
791 changes: 788 additions & 3 deletions sparrow-py/Cargo.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions sparrow-py/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Python library for building and executing Kaskada queries.

[dependencies]
itertools = "0.11.0"
arrow = {version = "44.0.0", features = ["pyarrow"] }
pyo3 = {version = "0.19.1", features = ["abi3-py37", "extension-module", "generate-import-lib"]}

[lib]
Expand Down
168 changes: 167 additions & 1 deletion sparrow-py/poetry.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions sparrow-py/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ description = "Kaskada query builder and local execution engine."
authors = []

[tool.poetry.dependencies]
pandas = "^2.0.3"
python = ">=3.8,<4.0"

[tool.poetry.dev-dependencies]
Expand All @@ -20,6 +21,7 @@ furo = ">=2021.11.12"
isort = ">=5.10.1"
mypy = ">=0.930"
pep8-naming = ">=0.12.1"
pyarrow = "^12.0.1"
pytest = ">=6.2.5"
pyupgrade = ">=2.29.1"
safety = ">=1.10.3"
Expand Down
2 changes: 1 addition & 1 deletion sparrow-py/pysrc/sparrow_py/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self, name: str, *args: "Expr") -> None:
Raises
------
TypeError # noqa: DAR402
TypeError
If the argument types are invalid for the given function.
"""
ffi_args = [arg.ffi for arg in args]
Expand Down
4 changes: 4 additions & 0 deletions sparrow-py/pysrc/sparrow_py/ffi.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Sequence
from sparrow_py.udf import Udf
import pyarrow as pa

class Session:
def __init__(self) -> None: ...
Expand All @@ -12,3 +14,5 @@ class Expr:
def __repr__(self) -> str: ...
def __str__(self) -> str: ...
def data_type_string(self) -> str: ...

# Implemented in the Rust extension: invokes `udf.run_pyarrow(result_type, *args)`
# and returns the array it produces. Each input array is converted into a Rust
# Arrow array and back before the call, and the returned array's data type is
# checked against `result_type`.
def call_udf(udf: Udf, result_type: pa.DataType, *args: pa.Array) -> pa.Array: ...
33 changes: 33 additions & 0 deletions sparrow-py/pysrc/sparrow_py/udf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import functools
import pandas as pd
import pyarrow as pa
from typing import Callable

# TODO: Allow functions to return `pd.DataFrame` for struct arrays.
FuncType = Callable[..., pd.Series]

class Udf(object):
    """A named, pandas-based user-defined function callable on Arrow arrays.

    Attributes are set in ``__init__``:
      * ``name`` -- the name the UDF was registered under.
      * ``func`` -- the wrapped Python callable operating on pandas values.
      * ``signature`` -- the signature string supplied at registration.
    """

    def __init__(self, name, func: FuncType, signature: str) -> None:
        # Copy the wrapped function's metadata onto this instance first, so
        # the explicit attributes assigned below always take precedence.
        functools.update_wrapper(self, func)
        self.name = name
        self.func = func
        self.signature = signature

    def run_pyarrow(self, result_type: pa.DataType, *args: pa.Array) -> pa.Array:
        """Run the wrapped function on Arrow inputs and return an Arrow array.

        Each argument is converted with ``to_pandas()``, the wrapped function
        is called with the converted values, and a ``pd.Series`` result is
        converted back to an Arrow array of ``result_type``.

        Raises
        ------
        TypeError
            If the wrapped function returns anything other than a ``pd.Series``.
        """
        # TODO: I believe this will return a series for simple arrays, and a
        # dataframe for struct arrays. We should explore how this handles
        # different types.
        pandas_inputs = tuple(array.to_pandas() for array in args)
        outcome = self.func(*pandas_inputs)

        if not isinstance(outcome, pd.Series):
            raise TypeError(f'Unsupported result type: {type(outcome)}')
        return pa.Array.from_pandas(outcome, type=result_type)


def fenl_udf(name: str, signature: str):
    """Decorator factory that turns a pandas-based function into a `Udf`.

    Parameters
    ----------
    name : str
        Name to register the UDF under.
    signature : str
        Signature string stored on the resulting `Udf`.

    Returns
    -------
    A decorator that wraps the decorated function in a `Udf` carrying the
    given ``name`` and ``signature``.
    """
    def decorator(func: FuncType):
        # No debug output here: decorators run at import time, so printing
        # would write to stdout whenever a module defining UDFs is imported.
        return Udf(name, func, signature)
    return decorator
27 changes: 27 additions & 0 deletions sparrow-py/pytests/udf_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from sparrow_py.udf import Udf, fenl_udf
from sparrow_py.ffi import call_udf
import pandas as pd
import pyarrow as pa

@fenl_udf("add", "add(x: number, y: number) -> number")
def add(x: pd.Series, y: pd.Series) -> pd.Series:
    """Return the element-wise sum of ``x`` and ``y``."""
    total = x + y
    return total

def test_numeric_udf_pure_python() -> None:
    """Check that a decorated UDF runs correctly when driven purely from Python."""
    # The decorator must have replaced the function with a `Udf` instance.
    assert type(add) == Udf

    lhs = pa.array([1, 12, 17, 23, 28], type=pa.int8())
    rhs = pa.array([1, 13, 18, 20, 4], type=pa.int8())
    expected = pa.array([2, 25, 35, 43, 32], type=pa.int8())

    actual = add.run_pyarrow(pa.int8(), lhs, rhs)
    print(actual)
    assert actual == expected


def test_numeric_udf_rust() -> None:
    """Check that a decorated UDF runs correctly when invoked via the Rust FFI."""
    lhs = pa.array([1, 12, 17, 23, 28], type=pa.int8())
    rhs = pa.array([1, 13, 18, 20, 4], type=pa.int8())
    expected = pa.array([2, 25, 35, 43, 32], type=pa.int8())

    actual = call_udf(add, pa.int8(), lhs, rhs)
    print(actual)
    assert actual == expected
2 changes: 2 additions & 0 deletions sparrow-py/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ use pyo3::prelude::*;

mod expr;
mod session;
mod udf;

/// A Python module implemented in Rust. The name of this function must match
/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
/// import the module.
#[pymodule]
fn ffi(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(udf::call_udf, m)?)?;
m.add_class::<session::Session>()?;
m.add_class::<expr::Expr>()?;

Expand Down
32 changes: 32 additions & 0 deletions sparrow-py/src/udf.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use arrow::array::ArrayData;
use arrow::datatypes::DataType;
use arrow::pyarrow::{FromPyArrow, ToPyArrow};
use pyo3::prelude::*;
use pyo3::types::PyTuple;

#[pyfunction]
#[pyo3(signature = (udf, result_type, *args))]
pub(super) fn call_udf<'py>(
py: Python<'py>,
udf: &'py PyAny,
result_type: &'py PyAny,
args: &'py PyTuple,
) -> PyResult<&'py PyAny> {
let result_type = DataType::from_pyarrow(result_type)?;

// 1. Make sure we can convert each input to and from arrow arrays.
let mut udf_args = Vec::with_capacity(args.len() + 1);
udf_args.push(result_type.to_pyarrow(py)?);
for arg in args {
let array_data = ArrayData::from_pyarrow(arg)?;
let py_array: PyObject = array_data.to_pyarrow(py)?;
udf_args.push(py_array);
}
let args = PyTuple::new(py, udf_args);
let result = udf.call_method("run_pyarrow", args, None)?;

let array_data: ArrayData = ArrayData::from_pyarrow(result)?;
assert_eq!(array_data.data_type(), &result_type);

Ok(result)
}

0 comments on commit 44ef3f9

Please sign in to comment.