-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Prototype calling Python UDFs from rust (#559)
- Loading branch information
1 parent
01b0d6c
commit 44ef3f9
Showing
10 changed files
with
1,057 additions
and
5 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import functools | ||
import pandas as pd | ||
import pyarrow as pa | ||
from typing import Callable | ||
|
||
# TODO: Allow functions to return `pd.DataFrame` for struct arrays. | ||
FuncType = Callable[..., pd.Series] | ||
|
||
class Udf(object): | ||
def __init__(self, name, func: FuncType, signature: str) -> None: | ||
functools.update_wrapper(self, func) | ||
self.name = name | ||
self.func = func | ||
self.signature = signature | ||
|
||
def run_pyarrow(self, result_type: pa.DataType, *args: pa.Array) -> pa.Array: | ||
# TODO: I believe this will return a series for simple arrays, and a | ||
# dataframe for struct arrays. We should explore how this handles | ||
# different types. | ||
pd_args = [arg.to_pandas() for arg in args] | ||
pd_result = self.func(*pd_args) | ||
|
||
if isinstance(pd_result, pd.Series): | ||
return pa.Array.from_pandas(pd_result, type=result_type) | ||
else: | ||
raise TypeError(f'Unsupported result type: {type(pd_result)}') | ||
|
||
|
||
def fenl_udf(name: str, signature: str): | ||
def decorator(func: FuncType): | ||
print(type(func)) | ||
return Udf(name, func, signature) | ||
return decorator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from sparrow_py.udf import Udf, fenl_udf | ||
from sparrow_py.ffi import call_udf | ||
import pandas as pd | ||
import pyarrow as pa | ||
|
||
@fenl_udf("add", "add(x: number, y: number) -> number") | ||
def add(x: pd.Series, y: pd.Series) -> pd.Series: | ||
return x + y | ||
|
||
def test_numeric_udf_pure_python() -> None: | ||
"""Test the python side of UDFs.""" | ||
assert type(add) == Udf | ||
|
||
x = pa.array([1, 12, 17, 23, 28], type=pa.int8()) | ||
y = pa.array([1, 13, 18, 20, 4], type=pa.int8()) | ||
result = add.run_pyarrow(pa.int8(), x, y) | ||
print(result) | ||
assert result == pa.array([2, 25, 35, 43, 32], type=pa.int8()) | ||
|
||
|
||
def test_numeric_udf_rust() -> None: | ||
"""Test the rust side of UDFs.""" | ||
x = pa.array([1, 12, 17, 23, 28], type=pa.int8()) | ||
y = pa.array([1, 13, 18, 20, 4], type=pa.int8()) | ||
result = call_udf(add, pa.int8(), x, y) | ||
print(result) | ||
assert result == pa.array([2, 25, 35, 43, 32], type=pa.int8()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
use arrow::array::ArrayData; | ||
use arrow::datatypes::DataType; | ||
use arrow::pyarrow::{FromPyArrow, ToPyArrow}; | ||
use pyo3::prelude::*; | ||
use pyo3::types::PyTuple; | ||
|
||
#[pyfunction] | ||
#[pyo3(signature = (udf, result_type, *args))] | ||
pub(super) fn call_udf<'py>( | ||
py: Python<'py>, | ||
udf: &'py PyAny, | ||
result_type: &'py PyAny, | ||
args: &'py PyTuple, | ||
) -> PyResult<&'py PyAny> { | ||
let result_type = DataType::from_pyarrow(result_type)?; | ||
|
||
// 1. Make sure we can convert each input to and from arrow arrays. | ||
let mut udf_args = Vec::with_capacity(args.len() + 1); | ||
udf_args.push(result_type.to_pyarrow(py)?); | ||
for arg in args { | ||
let array_data = ArrayData::from_pyarrow(arg)?; | ||
let py_array: PyObject = array_data.to_pyarrow(py)?; | ||
udf_args.push(py_array); | ||
} | ||
let args = PyTuple::new(py, udf_args); | ||
let result = udf.call_method("run_pyarrow", args, None)?; | ||
|
||
let array_data: ArrayData = ArrayData::from_pyarrow(result)?; | ||
assert_eq!(array_data.data_type(), &result_type); | ||
|
||
Ok(result) | ||
} |