feat(rust, python): add streamable udfs #6614

Merged: 1 commit, Feb 1, 2023
29 changes: 29 additions & 0 deletions polars/polars-lazy/polars-pipe/src/executors/operators/function.rs
@@ -0,0 +1,29 @@
use polars_core::error::PolarsResult;
use polars_plan::prelude::*;

use crate::operators::{DataChunk, Operator, OperatorResult, PExecutionContext};

#[derive(Clone)]
pub struct FunctionOperator {
    pub(crate) function: FunctionNode,
}

impl Operator for FunctionOperator {
    fn execute(
        &mut self,
        _context: &PExecutionContext,
        chunk: &DataChunk,
    ) -> PolarsResult<OperatorResult> {
        Ok(OperatorResult::Finished(
            chunk.with_data(self.function.evaluate(chunk.data.clone())?),
        ))
    }

    fn split(&self, _thread_no: usize) -> Box<dyn Operator> {
        Box::new(self.clone())
    }

    fn fmt(&self) -> &str {
        "function"
    }
}
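The new FunctionOperator is the whole bridge into the streaming engine: execute applies the wrapped FunctionNode to one DataChunk and immediately returns Finished, so a UDF only ever sees a single batch at a time. That makes per-batch consistency the contract a streamable UDF must satisfy. A rough Python illustration of that contract (the function f is hypothetical, not part of the polars-pipe API): applying it per batch and concatenating must match applying it once to the full frame.

```python
import polars as pl

def f(df: pl.DataFrame) -> pl.DataFrame:
    # Elementwise, so it behaves the same per batch as on the full frame.
    return df.with_columns((pl.col("a") * 2).alias("a2"))

batches = [pl.DataFrame({"a": [1, 2]}), pl.DataFrame({"a": [3, 4]})]

whole = f(pl.concat(batches))
per_batch = pl.concat([f(b) for b in batches])
assert whole.frame_equal(per_batch)  # the streamable-UDF contract
```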
6 changes: 4 additions & 2 deletions polars/polars-lazy/polars-pipe/src/executors/operators/mod.rs
@@ -1,7 +1,9 @@
-mod dummy;
 mod filter;
+mod function;
+mod placeholder;
 mod projection;

-pub(crate) use dummy::PlaceHolder;
 pub(crate) use filter::*;
+pub(crate) use function::*;
+pub(crate) use placeholder::PlaceHolder;
 pub(crate) use projection::*;
6 changes: 6 additions & 0 deletions polars/polars-lazy/polars-pipe/src/pipeline/convert.rs
@@ -320,6 +320,12 @@ where
             };
             Box::new(op) as Box<dyn Operator>
         }
+        MapFunction { function, .. } => {
+            let op = operators::FunctionOperator {
+                function: function.clone(),
+            };
+            Box::new(op) as Box<dyn Operator>
+        }

         lp => {
             panic!("operator {lp:?} not (yet) supported")
@@ -325,7 +325,7 @@ pub(super) fn strptime(s: &Series, options: &StrpTimeOptions) -> PolarsResult<Series>
                 .into(),
             ));
         }
-        #[cfg(feature = "regex")]
+        #[cfg(feature = "timezones")]
         (false, Some(fmt)) => TZ_AWARE_RE.is_match(fmt),
         (false, _) => false,
     };
2 changes: 1 addition & 1 deletion polars/polars-lazy/polars-plan/src/logical_plan/aexpr.rs
@@ -118,7 +118,7 @@ impl AExpr {
         use AExpr::*;
         match self {
             Function { options, .. } | AnonymousFunction { options, .. } => {
-                options.collect_groups == ApplyOptions::ApplyGroups
+                options.is_groups_sensitive()
             }
             Sort { .. }
             | SortBy { .. }
1 change: 1 addition & 0 deletions polars/polars-lazy/polars-plan/src/logical_plan/builder.rs
@@ -687,6 +687,7 @@ impl LogicalPlanBuilder {
                 schema,
                 predicate_pd: optimizations.predicate_pushdown,
                 projection_pd: optimizations.projection_pushdown,
+                streamable: optimizations.streaming,
                 fmt_str: name,
             },
         }
13 changes: 13 additions & 0 deletions polars/polars-lazy/polars-plan/src/logical_plan/functions/mod.rs
@@ -26,6 +26,7 @@ pub enum FunctionNode {
         predicate_pd: bool,
         /// allow projection pushdown optimizations
         projection_pd: bool,
+        streamable: bool,
         // used for formatting
         #[cfg_attr(feature = "serde", serde(skip))]
         fmt_str: &'static str,
@@ -70,6 +71,18 @@ impl PartialEq for FunctionNode {
 }

 impl FunctionNode {
+    /// Whether this function can run on batches of data at a time.
+    pub fn is_streamable(&self) -> bool {
+        use FunctionNode::*;
+        match self {
+            Rechunk | Pipeline { .. } => false,
+            #[cfg(feature = "merge_sorted")]
+            MergeSorted { .. } => false,
+            DropNulls { .. } | FastProjection { .. } | Unnest { .. } => true,
+            Opaque { streamable, .. } => *streamable,
+        }
+    }
+
     pub(crate) fn schema<'a>(
         &self,
         input_schema: &'a SchemaRef,
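Note that is_streamable answers conservatively: Rechunk and nested Pipelines need the whole input, MergeSorted is excluded under its feature flag, and an Opaque UDF only streams when it explicitly opted in. A small Python counterexample showing why the opt-in matters (g is hypothetical, not library code): a global aggregation gives a different answer per batch than on the full frame.

```python
import polars as pl

def g(df: pl.DataFrame) -> pl.DataFrame:
    # Depends on the whole column, so it is NOT safe to run per batch.
    return df.select(pl.col("a").mean())

batches = [pl.DataFrame({"a": [1.0, 2.0]}), pl.DataFrame({"a": [30.0]})]

whole = g(pl.concat(batches))                   # mean of [1, 2, 30] -> 11.0
per_batch = pl.concat([g(b) for b in batches])  # per-batch means: 1.5, 30.0
assert not whole.frame_equal(per_batch)         # batched result diverges
```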
9 changes: 6 additions & 3 deletions polars/polars-lazy/polars-plan/src/logical_plan/options.rs
@@ -217,9 +217,12 @@ pub struct FunctionOptions {
 }

 impl FunctionOptions {
-    /// Whether this can simply applied elementwise
-    pub fn is_mappable(&self) -> bool {
-        !matches!(self.collect_groups, ApplyOptions::ApplyGroups)
+    /// Any function that is sensitive to the number of elements in a group
+    /// - Aggregations
+    /// - Sorts
+    /// - Counts
+    pub fn is_groups_sensitive(&self) -> bool {
+        matches!(self.collect_groups, ApplyOptions::ApplyGroups)
     }
 }

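is_groups_sensitive is the exact negation of the removed is_mappable: an expression is groups-sensitive when its output depends on the group's full contents (its size or ordering) rather than on each row alone. Roughly, in polars expressions (illustrative only):

```python
import polars as pl

df = pl.DataFrame({"g": ["x", "x", "y"], "a": [3, 1, 2]})

# Elementwise (mappable): each output row depends only on its own input
# row, so the expression can be evaluated chunk by chunk.
print(df.groupby("g").agg(pl.col("a") * 2))

# Groups-sensitive: the result depends on the whole group
# (aggregations, counts, sorts), so it cannot be streamed naively.
print(df.groupby("g").agg(pl.col("a").count()))
print(df.groupby("g").agg(pl.col("a").sort()))
```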
2 changes: 2 additions & 0 deletions polars/polars-lazy/src/frame/mod.rs
@@ -302,6 +302,7 @@ impl LazyFrame {
             AllowedOptimizations {
                 projection_pushdown: false,
                 predicate_pushdown: false,
+                streaming: true,
                 ..Default::default()
             },
             Some(Arc::new(udf_schema)),
@@ -1198,6 +1199,7 @@ impl LazyFrame {
             AllowedOptimizations {
                 slice_pushdown: false,
                 predicate_pushdown: false,
+                streaming: false,
                 ..Default::default()
             }
         } else {
@@ -286,7 +286,7 @@ impl PhysicalExpr for TernaryExpr {
                 Expr::Agg(_) => has_agg = true,
                 Expr::Function { options, .. }
                 | Expr::AnonymousFunction { options, .. }
-                    if !options.is_mappable() =>
+                    if options.is_groups_sensitive() =>
                 {
                     has_agg = true
                 }
105 changes: 69 additions & 36 deletions polars/polars-lazy/src/physical_plan/streaming/convert.rs
@@ -20,6 +20,13 @@ use crate::physical_plan::PhysicalExpr;

 pub struct Wrap(Arc<dyn PhysicalExpr>);

+type IsSink = bool;
+// a rhs of a join will be replaced later
+type IsRhsJoin = bool;
+
+const IS_SINK: bool = true;
+const IS_RHS_JOIN: bool = true;
+
 impl PhysicalPipedExpr for Wrap {
     fn evaluate(&self, chunk: &DataChunk, state: &dyn Any) -> PolarsResult<Series> {
         let state = state.downcast_ref::<ExecutionState>().unwrap();
@@ -105,6 +112,30 @@ fn streamable_join(join_type: &JoinType) -> bool {
     }
 }

+// The index of the pipeline tree we are building at this moment
+// if we have a node we cannot do streaming, we have finished that pipeline tree
+// and start a new one.
+type CurrentIdx = usize;
+
+fn process_non_streamable_node(
+    current_idx: &mut CurrentIdx,
+    state: &mut Branch,
+    stack: &mut Vec<(Node, Branch, CurrentIdx)>,
+    scratch: &mut Vec<Node>,
+    pipeline_trees: &mut Vec<Vec<Branch>>,
+    lp: &ALogicalPlan,
+) {
+    if state.streamable {
+        *current_idx += 1;
+        pipeline_trees.push(vec![]);
+    }
+    state.streamable = false;
+    lp.copy_inputs(scratch);
+    while let Some(input) = scratch.pop() {
+        stack.push((input, Branch::default(), *current_idx))
+    }
+}
+
 pub(crate) fn insert_streaming_nodes(
     root: Node,
     lp_arena: &mut Arena<ALogicalPlan>,
@@ -120,10 +151,6 @@ pub(crate) fn insert_streaming_nodes(

     let mut stack = Vec::with_capacity(16);

-    // The index of the pipeline tree we are building at this moment
-    // if we have a node we cannot do streaming, we have finished that pipeline tree
-    // and start a new one.
-    type CurrentIdx = usize;
     stack.push((root, Branch::default(), 0 as CurrentIdx));

     // A state holds a full pipeline until the breaker
@@ -148,17 +175,17 @@
         match lp_arena.get(root) {
             Selection { input, predicate } if is_streamable(*predicate, expr_arena) => {
                 state.streamable = true;
-                state.operators_sinks.push((false, false, root));
+                state.operators_sinks.push((!IS_SINK, !IS_RHS_JOIN, root));
                 stack.push((*input, state, current_idx))
             }
             HStack { input, exprs, .. } if all_streamable(exprs, expr_arena) => {
                 state.streamable = true;
-                state.operators_sinks.push((false, false, root));
+                state.operators_sinks.push((!IS_SINK, !IS_RHS_JOIN, root));
                 stack.push((*input, state, current_idx))
             }
             Slice { input, offset, .. } if *offset >= 0 => {
                 state.streamable = true;
-                state.operators_sinks.push((true, false, root));
+                state.operators_sinks.push((IS_SINK, !IS_RHS_JOIN, root));
                 stack.push((*input, state, current_idx))
             }
             FileSink { input, .. } => {
@@ -175,30 +202,39 @@
                     && all_column(by_column, expr_arena) =>
             {
                 state.streamable = true;
-                state.operators_sinks.push((true, false, root));
-                stack.push((*input, state, current_idx))
-            }
-            MapFunction {
-                input,
-                function: FunctionNode::FastProjection { .. },
-            } => {
-                state.streamable = true;
-                state.operators_sinks.push((false, false, root));
+                state.operators_sinks.push((IS_SINK, !IS_RHS_JOIN, root));
                 stack.push((*input, state, current_idx))
             }
             Projection { input, expr, .. } if all_streamable(expr, expr_arena) => {
                 state.streamable = true;
-                state.operators_sinks.push((false, false, root));
+                state.operators_sinks.push((!IS_SINK, !IS_RHS_JOIN, root));
                 stack.push((*input, state, current_idx))
             }
+            // Rechunks are ignored
             MapFunction {
                 input,
                 function: FunctionNode::Rechunk,
             } => {
-                // we ignore a rechunk
                 state.streamable = true;
                 stack.push((*input, state, current_idx))
             }
+            // Streamable functions will be converted
+            lp @ MapFunction { input, function } => {
+                if function.is_streamable() {
+                    state.streamable = true;
+                    state.operators_sinks.push((!IS_SINK, !IS_RHS_JOIN, root));
+                    stack.push((*input, state, current_idx))
+                } else {
+                    process_non_streamable_node(
+                        &mut current_idx,
+                        &mut state,
+                        &mut stack,
+                        scratch,
+                        &mut pipeline_trees,
+                        lp,
+                    )
+                }
+            }
             #[cfg(feature = "csv-file")]
             CsvScan { .. } => {
                 if state.streamable {
@@ -245,7 +281,9 @@
                 // rhs is second, so that is first on the stack
                 let mut state_right = state;
                 state_right.join_count = 0;
-                state_right.operators_sinks.push((true, true, root));
+                state_right
+                    .operators_sinks
+                    .push((IS_SINK, IS_RHS_JOIN, root));
                 stack.push((input_right, state_right, current_idx));

                 // we want to traverse lhs first, so push it latest on the stack
@@ -255,7 +293,9 @@
                     join_count,
                     ..Default::default()
                 };
-                state_left.operators_sinks.push((true, false, root));
+                state_left
+                    .operators_sinks
+                    .push((IS_SINK, !IS_RHS_JOIN, root));
                 stack.push((input_left, state_left, current_idx));
}
// add globbing patterns
@@ -309,23 +349,20 @@
                     .all(|dt| allowed_dtype(dt, string_cache))
                 {
                     state.streamable = true;
-                    state.operators_sinks.push((true, false, root));
+                    state.operators_sinks.push((IS_SINK, !IS_RHS_JOIN, root));
                     stack.push((*input, state, current_idx))
                 } else {
                     stack.push((*input, Branch::default(), current_idx))
                 }
             }
-            lp => {
-                if state.streamable {
-                    current_idx += 1;
-                    pipeline_trees.push(vec![]);
-                }
-                state.streamable = false;
-                lp.copy_inputs(scratch);
-                while let Some(input) = scratch.pop() {
-                    stack.push((input, Branch::default(), current_idx))
-                }
-            }
+            lp => process_non_streamable_node(
+                &mut current_idx,
+                &mut state,
+                &mut stack,
+                scratch,
+                &mut pipeline_trees,
+                lp,
+            ),
}
}
let mut inserted = false;
@@ -448,10 +485,6 @@
Ok(inserted)
}

-type IsSink = bool;
-// a rhs of a join will be replaced later
-type IsRhsJoin = bool;
-
 #[derive(Default, Debug, Clone)]
 struct Branch {
     streamable: bool,
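The traversal above keeps a CurrentIdx into a forest of pipeline trees: streamable nodes extend the current branch, while any non-streamable node closes the current tree and restarts counting for its inputs, which is exactly the bookkeeping process_non_streamable_node factors out. A loose Python sketch of that idea (simplified: no Branch state, sinks, or join counts):

```python
def split_into_pipeline_trees(root, inputs_of, is_streamable):
    """Group plan nodes into pipeline trees cut at non-streamable nodes.

    inputs_of(node) returns the upstream nodes; is_streamable(node) says
    whether the node can run in the streaming engine. A simplified mirror
    of the stack walk in insert_streaming_nodes.
    """
    trees = [[]]         # forest of pipeline trees
    stack = [(root, 0)]  # (node, index of the tree being built)
    current_idx = 0
    while stack:
        node, idx = stack.pop()
        if is_streamable(node):
            trees[idx].append(node)
            next_idx = idx
        else:
            # A non-streamable node finishes the current tree; its inputs
            # seed a fresh tree (cf. process_non_streamable_node).
            current_idx += 1
            trees.append([])
            next_idx = current_idx
        for inp in inputs_of(node):
            stack.append((inp, next_idx))
    return [t for t in trees if t]

# Toy plan: scan -> filter -> sort (non-streamable) -> select
plan = {"select": ["sort"], "sort": ["filter"], "filter": ["scan"], "scan": []}
print(split_into_pipeline_trees("select", plan.get, lambda n: n != "sort"))
# [['select'], ['filter', 'scan']]
```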
10 changes: 8 additions & 2 deletions py-polars/polars/internals/lazyframe/frame.py
@@ -3794,6 +3794,7 @@ def map(
         no_optimizations: bool = False,
         schema: None | SchemaDict = None,
         validate_output_schema: bool = True,
+        streamable: bool = False,
     ) -> LDF:
         """
         Apply a custom function.
@@ -3820,6 +3821,10 @@
             the output schema of this function will be checked with the expected schema.
             Setting this to ``False`` will not do this check, but may lead to hard to
             debug bugs.
+        streamable
+            Whether the given function is eligible to run in the streaming engine.
+            That means the function must produce the same result when executed on
+            batches as it would when executed on the full dataset.

Warnings
--------
@@ -3856,8 +3861,9 @@
                 predicate_pushdown,
                 projection_pushdown,
                 slice_pushdown,
-                schema,
-                validate_output_schema,
+                streamable=streamable,
+                schema=schema,
+                validate_output=validate_output_schema,
             )
         )

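End to end, a Python UDF that opts in with streamable=True can now run inside the streaming engine. A minimal usage sketch (the lambda is illustrative, and collect(streaming=True) is the streaming entry point available at the time):

```python
import polars as pl

lf = pl.DataFrame({"a": [1, 2, 3, 4]}).lazy()

out = lf.map(
    # Elementwise and schema-preserving: doubles "a" in place.
    lambda df: df.with_columns(pl.col("a") * 2),
    streamable=True,
).collect(streaming=True)
print(out)
```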