feat(python): Make expressions containing Python UDFs serializable #18135

Merged · 6 commits · Sep 2, 2024
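For context, a minimal sketch of what this change enables from the Python API, assuming a Polars build that includes this PR; the function, column name, and data below are placeholders rather than anything taken from the PR itself.

```python
import pickle

import polars as pl


def double(x: int) -> int:
    return x * 2


# Per the PR title, an expression containing a Python UDF should now survive
# a serialization round trip; `double` and column "a" are placeholders.
expr = pl.col("a").map_elements(double, return_dtype=pl.Int64)

restored = pickle.loads(pickle.dumps(expr))
print(pl.DataFrame({"a": [1, 2, 3]}).select(restored))
```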
76 changes: 2 additions & 74 deletions crates/polars-plan/src/client/check.rs
@@ -1,30 +1,15 @@
use polars_core::error::{polars_err, PolarsResult};
use polars_io::path_utils::is_cloud_url;

use crate::dsl::Expr;
use crate::plans::options::SinkType;
use crate::plans::{DslFunction, DslPlan, FileScan, FunctionIR};
use crate::plans::{DslPlan, FileScan};

/// Assert that the given [`DslPlan`] is eligible to be executed on Polars Cloud.
pub(super) fn assert_cloud_eligible(dsl: &DslPlan) -> PolarsResult<()> {
let mut expr_stack = vec![];
for plan_node in dsl.into_iter() {
match plan_node {
DslPlan::MapFunction { function, .. } => match function {
DslFunction::FunctionIR(FunctionIR::Opaque { .. }) => {
return ineligible_error("contains opaque function")
},
#[cfg(feature = "python")]
DslFunction::OpaquePython { .. } => {
return ineligible_error("contains Python function")
},
_ => (),
},
#[cfg(feature = "python")]
DslPlan::PythonScan { .. } => return ineligible_error("contains Python scan"),
DslPlan::GroupBy { apply: Some(_), .. } => {
return ineligible_error("contains Python function in group by operation")
},
DslPlan::Scan { paths, .. }
if paths.lock().unwrap().0.iter().any(|p| !is_cloud_url(p)) =>
{
@@ -39,23 +24,7 @@ pub(super) fn assert_cloud_eligible(dsl: &DslPlan) -> PolarsResult<()> {
return ineligible_error("contains sink to non-cloud location");
}
},
plan => {
plan.get_expr(&mut expr_stack);

for expr in expr_stack.drain(..) {
for expr_node in expr.into_iter() {
match expr_node {
Expr::AnonymousFunction { .. } => {
return ineligible_error("contains anonymous function")
},
Expr::RenameAlias { .. } => {
return ineligible_error("contains custom name remapping")
},
_ => (),
}
}
}
},
_ => (),
}
}
Ok(())
@@ -101,47 +70,6 @@ impl DslPlan {
PythonScan { .. } => (),
}
}

fn get_expr<'a>(&'a self, scratch: &mut Vec<&'a Expr>) {
use DslPlan::*;
match self {
Filter { predicate, .. } => scratch.push(predicate),
Scan { predicate, .. } => {
if let Some(expr) = predicate {
scratch.push(expr)
}
},
DataFrameScan { filter, .. } => {
if let Some(expr) = filter {
scratch.push(expr)
}
},
Select { expr, .. } => scratch.extend(expr),
HStack { exprs, .. } => scratch.extend(exprs),
Sort { by_column, .. } => scratch.extend(by_column),
GroupBy { keys, aggs, .. } => {
scratch.extend(keys);
scratch.extend(aggs);
},
Join {
left_on, right_on, ..
} => {
scratch.extend(left_on);
scratch.extend(right_on);
},
Cache { .. }
| Distinct { .. }
| Slice { .. }
| MapFunction { .. }
| Union { .. }
| HConcat { .. }
| ExtContext { .. }
| Sink { .. }
| IR { .. } => (),
#[cfg(feature = "python")]
PythonScan { .. } => (),
}
}
}

pub struct DslPlanIter<'a> {
3 changes: 0 additions & 3 deletions crates/polars-plan/src/dsl/expr.rs
@@ -143,8 +143,6 @@ pub enum Expr {
Len,
/// Take the nth column in the `DataFrame`
Nth(i64),
// skipped fields must be last otherwise serde fails in pickle
#[cfg_attr(feature = "serde", serde(skip))]
RenameAlias {
function: SpecialEq<Arc<dyn RenameAliasFn>>,
expr: Arc<Expr>,
@@ -157,7 +155,6 @@
/// function to apply
function: SpecialEq<Arc<dyn SeriesUdf>>,
/// output dtype of the function
#[cfg_attr(feature = "serde", serde(skip))]
output_type: GetOutput,
options: FunctionOptions,
},
100 changes: 92 additions & 8 deletions crates/polars-plan/src/dsl/expr_dyn_fn.rs
@@ -1,5 +1,6 @@
use std::fmt::Formatter;
use std::ops::Deref;
use std::sync::Arc;

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
@@ -17,7 +18,7 @@ pub trait SeriesUdf: Send + Sync {
fn call_udf(&self, s: &mut [Series]) -> PolarsResult<Option<Series>>;

fn try_serialize(&self, _buf: &mut Vec<u8>) -> PolarsResult<()> {
polars_bail!(ComputeError: "serialize not supported for this 'opaque' function")
polars_bail!(ComputeError: "serialization not supported for this 'opaque' function")
}

// Needed for python functions. After they are deserialized we first check if they
@@ -46,30 +47,29 @@ impl Serialize for SpecialEq<Arc<dyn SeriesUdf>> {

#[cfg(feature = "serde")]
impl<'a> Deserialize<'a> for SpecialEq<Arc<dyn SeriesUdf>> {
fn deserialize<D>(_deserializer: D) -> std::result::Result<Self, D::Error>
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: Deserializer<'a>,
{
use serde::de::Error;
#[cfg(feature = "python")]
{
use crate::dsl::python_udf::MAGIC_BYTE_MARK;
let buf = Vec::<u8>::deserialize(_deserializer)?;
let buf = Vec::<u8>::deserialize(deserializer)?;

if buf.starts_with(MAGIC_BYTE_MARK) {
if buf.starts_with(python_udf::MAGIC_BYTE_MARK) {
let udf = python_udf::PythonUdfExpression::try_deserialize(&buf)
.map_err(|e| D::Error::custom(format!("{e}")))?;
Ok(SpecialEq::new(udf))
} else {
Err(D::Error::custom(
"deserialize not supported for this 'opaque' function",
"deserialization not supported for this 'opaque' function",
))
}
}
#[cfg(not(feature = "python"))]
{
Err(D::Error::custom(
"deserialize not supported for this 'opaque' function",
"deserialization not supported for this 'opaque' function",
))
}
}
@@ -125,9 +125,16 @@ impl Default for SpecialEq<Arc<dyn BinaryUdfOutputField>> {

pub trait RenameAliasFn: Send + Sync {
fn call(&self, name: &PlSmallStr) -> PolarsResult<PlSmallStr>;

fn try_serialize(&self, _buf: &mut Vec<u8>) -> PolarsResult<()> {
polars_bail!(ComputeError: "serialization not supported for this renaming function")
}
}

impl<F: Fn(&PlSmallStr) -> PolarsResult<PlSmallStr> + Send + Sync> RenameAliasFn for F {
impl<F> RenameAliasFn for F
where
F: Fn(&PlSmallStr) -> PolarsResult<PlSmallStr> + Send + Sync,
{
fn call(&self, name: &PlSmallStr) -> PolarsResult<PlSmallStr> {
self(name)
}
@@ -250,6 +257,10 @@ pub trait FunctionOutputField: Send + Sync {
cntxt: Context,
fields: &[Field],
) -> PolarsResult<Field>;

fn try_serialize(&self, _buf: &mut Vec<u8>) -> PolarsResult<()> {
polars_bail!(ComputeError: "serialization not supported for this output field")
}
}

pub type GetOutput = SpecialEq<Arc<dyn FunctionOutputField>>;
@@ -344,3 +355,76 @@
self(input_schema, cntxt, fields)
}
}

#[cfg(feature = "serde")]
impl Serialize for GetOutput {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: Serializer,
{
use serde::ser::Error;
let mut buf = vec![];
self.0
.try_serialize(&mut buf)
.map_err(|e| S::Error::custom(format!("{e}")))?;
serializer.serialize_bytes(&buf)
}
}

#[cfg(feature = "serde")]
impl<'a> Deserialize<'a> for GetOutput {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: Deserializer<'a>,
{
use serde::de::Error;
#[cfg(feature = "python")]
{
let buf = Vec::<u8>::deserialize(deserializer)?;

if buf.starts_with(python_udf::MAGIC_BYTE_MARK) {
let get_output = python_udf::PythonGetOutput::try_deserialize(&buf)
.map_err(|e| D::Error::custom(format!("{e}")))?;
Ok(SpecialEq::new(get_output))
} else {
Err(D::Error::custom(
"deserialization not supported for this output field",
))
}
}
#[cfg(not(feature = "python"))]
{
Err(D::Error::custom(
"deserialization not supported for this output field",
))
}
}
}

#[cfg(feature = "serde")]
impl Serialize for SpecialEq<Arc<dyn RenameAliasFn>> {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: Serializer,
{
use serde::ser::Error;
let mut buf = vec![];
self.0
.try_serialize(&mut buf)
.map_err(|e| S::Error::custom(format!("{e}")))?;
serializer.serialize_bytes(&buf)
}
}

#[cfg(feature = "serde")]
impl<'a> Deserialize<'a> for SpecialEq<Arc<dyn RenameAliasFn>> {
fn deserialize<D>(_deserializer: D) -> std::result::Result<Self, D::Error>
where
D: Deserializer<'a>,
{
use serde::de::Error;
Review comment (Member): So we could add a PythonRenameAlias as well, I presume?

Reply (stinodego, Member and Author, Sep 2, 2024): Yes - we'd have to do something similar to PythonUdfExpression, where we store the lambda as a PyObject. But it's not super trivial, and name.map and name.map_fields are implemented differently... I'd rather pick these up separately, as I don't think they are too important and it will require some minor refactoring.

Err(D::Error::custom(
"deserialization not supported for this renaming function",
))
}
}
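
To make the review thread above concrete: name.map and name.map_fields are the Python-side entry points that wrap a callable in a renaming function. The snippet below is illustrative only; the note about serialization failing is a reading of the Serialize/Deserialize impls above, not something demonstrated in this PR.

```python
import polars as pl

# Both of these wrap a Python callable in a renaming function
# (Expr::RenameAlias on the Rust side).
upper = pl.col("a").name.map(lambda n: n.upper())
fields = pl.col("s").name.map_fields(lambda n: n + "_renamed")

df = pl.DataFrame({"a": [1, 2], "s": [{"f": 1}, {"f": 2}]})
print(df.select(upper, fields))

# With no PythonRenameAlias counterpart yet, serializing an expression that
# contains one of these would be expected to fail with "serialization not
# supported for this renaming function".
```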
69 changes: 56 additions & 13 deletions crates/polars-plan/src/dsl/python_udf.rs
@@ -5,6 +5,7 @@ use polars_core::datatypes::{DataType, Field};
use polars_core::error::*;
use polars_core::frame::DataFrame;
use polars_core::prelude::Series;
use polars_core::schema::Schema;
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedBytes;
use pyo3::types::PyBytes;
@@ -17,14 +18,14 @@ use super::expr_dyn_fn::*;
use crate::constants::MAP_LIST_NAME;
use crate::prelude::*;

// Will be overwritten on python polar start up.
// Will be overwritten on Python Polars start up.
pub static mut CALL_SERIES_UDF_PYTHON: Option<
fn(s: Series, lambda: &PyObject) -> PolarsResult<Series>,
> = None;
pub static mut CALL_DF_UDF_PYTHON: Option<
fn(s: DataFrame, lambda: &PyObject) -> PolarsResult<DataFrame>,
> = None;
pub(super) const MAGIC_BYTE_MARK: &[u8] = "POLARS_PYTHON_UDF".as_bytes();
pub(super) const MAGIC_BYTE_MARK: &[u8] = "PLPYUDF".as_bytes();

#[derive(Clone, Debug)]
pub struct PythonFunction(pub PyObject);
@@ -141,7 +142,7 @@ impl PythonUdfExpression {
.unwrap();
let arg = (PyBytes::new_bound(py, remainder),);
let python_function = pickle.call1(arg).map_err(from_pyerr)?;
Ok(Arc::new(PythonUdfExpression::new(
Ok(Arc::new(Self::new(
python_function.into(),
output_type,
is_elementwise,
@@ -229,6 +230,54 @@ impl SeriesUdf for PythonUdfExpression {
}
}

/// Serializable version of [`GetOutput`] for Python UDFs.
pub struct PythonGetOutput {
return_dtype: Option<DataType>,
}

impl PythonGetOutput {
pub fn new(return_dtype: Option<DataType>) -> Self {
Self { return_dtype }
}

#[cfg(feature = "serde")]
pub(crate) fn try_deserialize(buf: &[u8]) -> PolarsResult<Arc<dyn FunctionOutputField>> {
// Skip header.
debug_assert!(buf.starts_with(MAGIC_BYTE_MARK));
let buf = &buf[MAGIC_BYTE_MARK.len()..];

let mut reader = Cursor::new(buf);
let return_dtype: Option<DataType> =
ciborium::de::from_reader(&mut reader).map_err(map_err)?;

Ok(Arc::new(Self::new(return_dtype)) as Arc<dyn FunctionOutputField>)
}
}

impl FunctionOutputField for PythonGetOutput {
fn get_field(
&self,
_input_schema: &Schema,
_cntxt: Context,
fields: &[Field],
) -> PolarsResult<Field> {
// Take the name of first field, just like [`GetOutput::map_field`].
let name = fields[0].name();
let return_dtype = match self.return_dtype {
Some(ref dtype) => dtype.clone(),
None => DataType::Unknown(Default::default()),
};
Ok(Field::new(name.clone(), return_dtype))
}

#[cfg(feature = "serde")]
fn try_serialize(&self, buf: &mut Vec<u8>) -> PolarsResult<()> {
buf.extend_from_slice(MAGIC_BYTE_MARK);
ciborium::ser::into_writer(&self.return_dtype, &mut *buf).unwrap();
Ok(())
}
}

impl Expr {
pub fn map_python(self, func: PythonUdfExpression, agg_list: bool) -> Expr {
let (collect_groups, name) = if agg_list {
@@ -241,16 +290,10 @@

let returns_scalar = func.returns_scalar;
let return_dtype = func.output_type.clone();
let output_type = GetOutput::map_field(move |fld| {
Ok(match return_dtype {
Some(ref dt) => Field::new(fld.name().clone(), dt.clone()),
None => {
let mut fld = fld.clone();
fld.coerce(DataType::Unknown(Default::default()));
fld
},
})
});

let output_field = PythonGetOutput::new(return_dtype);
let output_type = SpecialEq::new(Arc::new(output_field) as Arc<dyn FunctionOutputField>);

let mut flags = FunctionFlags::default() | FunctionFlags::OPTIONAL_RE_ENTRANT;
if returns_scalar {
flags |= FunctionFlags::RETURNS_SCALAR;
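
The PythonGetOutput type added in python_udf.rs resolves a UDF's output field from the stored return_dtype, keeping the first input field's name and falling back to Unknown when no dtype was given. A rough Python-side illustration of that behaviour, assuming map_elements and LazyFrame.collect_schema behave as in current Polars; the column name and dtypes here are arbitrary.

```python
import polars as pl

lf = pl.LazyFrame({"a": [1, 2, 3]})

# Explicit return_dtype: the planned schema keeps the input name "a" and
# uses the declared dtype.
with_dtype = lf.select(pl.col("a").map_elements(float, return_dtype=pl.Float64))
print(with_dtype.collect_schema())  # expected: a -> Float64

# No return_dtype: the plan records Unknown, and the concrete dtype is only
# known after the UDF has run (Polars may warn about the missing dtype).
without_dtype = lf.select(pl.col("a").map_elements(float))
print(without_dtype.collect().schema)  # expected: a -> Float64 at runtime
```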