Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Prototype calling Python UDFs from rust #559

Merged
merged 2 commits into from
Jul 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
791 changes: 788 additions & 3 deletions sparrow-py/Cargo.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions sparrow-py/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Python library for building and executing Kaskada queries.

[dependencies]
itertools = "0.11.0"
arrow = {version = "44.0.0", features = ["pyarrow"] }
pyo3 = {version = "0.19.1", features = ["abi3-py37", "extension-module", "generate-import-lib"]}

[lib]
Expand Down
168 changes: 167 additions & 1 deletion sparrow-py/poetry.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions sparrow-py/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ description = "Kaskada query builder and local execution engine."
authors = []

[tool.poetry.dependencies]
pandas = "^2.0.3"
python = ">=3.8,<4.0"

[tool.poetry.dev-dependencies]
Expand All @@ -20,6 +21,7 @@ furo = ">=2021.11.12"
isort = ">=5.10.1"
mypy = ">=0.930"
pep8-naming = ">=0.12.1"
pyarrow = "^12.0.1"
pytest = ">=6.2.5"
pyupgrade = ">=2.29.1"
safety = ">=1.10.3"
Expand Down
2 changes: 1 addition & 1 deletion sparrow-py/pysrc/sparrow_py/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self, name: str, *args: "Expr") -> None:

Raises
------
TypeError # noqa: DAR402
TypeError
If the argument types are invalid for the given function.
"""
ffi_args = [arg.ffi for arg in args]
Expand Down
4 changes: 4 additions & 0 deletions sparrow-py/pysrc/sparrow_py/ffi.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Sequence
from sparrow_py.udf import Udf
import pyarrow as pa

class Session:
def __init__(self) -> None: ...
Expand All @@ -12,3 +14,5 @@ class Expr:
def __repr__(self) -> str: ...
def __str__(self) -> str: ...
def data_type_string(self) -> str: ...

def call_udf(udf: Udf, result_type: pa.DataType, *args: pa.Array) -> pa.Array: ...
33 changes: 33 additions & 0 deletions sparrow-py/pysrc/sparrow_py/udf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import functools
import pandas as pd
import pyarrow as pa
from typing import Callable

# TODO: Allow functions to return `pd.DataFrame` for struct arrays.
FuncType = Callable[..., pd.Series]

class Udf(object):
def __init__(self, name, func: FuncType, signature: str) -> None:
functools.update_wrapper(self, func)
self.name = name
self.func = func
self.signature = signature

def run_pyarrow(self, result_type: pa.DataType, *args: pa.Array) -> pa.Array:
bjchambers marked this conversation as resolved.
Show resolved Hide resolved
# TODO: I believe this will return a series for simple arrays, and a
# dataframe for struct arrays. We should explore how this handles
# different types.
pd_args = [arg.to_pandas() for arg in args]
pd_result = self.func(*pd_args)

if isinstance(pd_result, pd.Series):
return pa.Array.from_pandas(pd_result, type=result_type)
else:
raise TypeError(f'Unsupported result type: {type(pd_result)}')


def fenl_udf(name: str, signature: str):
def decorator(func: FuncType):
print(type(func))
return Udf(name, func, signature)
return decorator
27 changes: 27 additions & 0 deletions sparrow-py/pytests/udf_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from sparrow_py.udf import Udf, fenl_udf
from sparrow_py.ffi import call_udf
import pandas as pd
import pyarrow as pa

@fenl_udf("add", "add(x: number, y: number) -> number")
def add(x: pd.Series, y: pd.Series) -> pd.Series:
return x + y

def test_numeric_udf_pure_python() -> None:
"""Test the python side of UDFs."""
assert type(add) == Udf

x = pa.array([1, 12, 17, 23, 28], type=pa.int8())
y = pa.array([1, 13, 18, 20, 4], type=pa.int8())
result = add.run_pyarrow(pa.int8(), x, y)
print(result)
assert result == pa.array([2, 25, 35, 43, 32], type=pa.int8())


def test_numeric_udf_rust() -> None:
"""Test the rust side of UDFs."""
x = pa.array([1, 12, 17, 23, 28], type=pa.int8())
y = pa.array([1, 13, 18, 20, 4], type=pa.int8())
result = call_udf(add, pa.int8(), x, y)
print(result)
assert result == pa.array([2, 25, 35, 43, 32], type=pa.int8())
2 changes: 2 additions & 0 deletions sparrow-py/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ use pyo3::prelude::*;

mod expr;
mod session;
mod udf;

/// A Python module implemented in Rust. The name of this function must match
/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
/// import the module.
#[pymodule]
fn ffi(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(udf::call_udf, m)?)?;
m.add_class::<session::Session>()?;
m.add_class::<expr::Expr>()?;

Expand Down
32 changes: 32 additions & 0 deletions sparrow-py/src/udf.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use arrow::array::ArrayData;
use arrow::datatypes::DataType;
use arrow::pyarrow::{FromPyArrow, ToPyArrow};
use pyo3::prelude::*;
use pyo3::types::PyTuple;

#[pyfunction]
#[pyo3(signature = (udf, result_type, *args))]
pub(super) fn call_udf<'py>(
py: Python<'py>,
udf: &'py PyAny,
result_type: &'py PyAny,
args: &'py PyTuple,
) -> PyResult<&'py PyAny> {
let result_type = DataType::from_pyarrow(result_type)?;

// 1. Make sure we can convert each input to and from arrow arrays.
let mut udf_args = Vec::with_capacity(args.len() + 1);
udf_args.push(result_type.to_pyarrow(py)?);
for arg in args {
let array_data = ArrayData::from_pyarrow(arg)?;
let py_array: PyObject = array_data.to_pyarrow(py)?;
udf_args.push(py_array);
}
let args = PyTuple::new(py, udf_args);
let result = udf.call_method("run_pyarrow", args, None)?;

let array_data: ArrayData = ArrayData::from_pyarrow(result)?;
assert_eq!(array_data.data_type(), &result_type);

Ok(result)
}
Loading