Commit ce9ce26: feat(rust, python): add streamable udfs

ritchie46 committed Feb 1, 2023
1 parent 2946e07 commit ce9ce26
Showing 10 changed files with 128 additions and 44 deletions.
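
What this enables: a MapFunction node whose function is elementwise can now run on the streaming engine as a pipeline operator, batch by batch, instead of forcing the plan back to the in-memory engine. A minimal usage sketch, assuming the Rust lazy API of this era (Expr::map, GetOutput::same_type and LazyFrame::with_streaming are illustrative assumptions, not part of this diff):

use polars::prelude::*;

fn main() -> PolarsResult<()> {
    let lf = df!["a" => [1i64, 2, 3]]?.lazy();
    // An elementwise UDF: each batch can be mapped independently,
    // so the streaming engine may run it as a pipeline operator.
    let out = lf
        .select([col("a").map(|s| Ok(&s * 2), GetOutput::same_type())])
        .with_streaming(true)
        .collect()?;
    println!("{out}");
    Ok(())
}

Whether a given node actually streams is decided by the is_streamable checks introduced below.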
29 changes: 29 additions & 0 deletions polars/polars-lazy/polars-pipe/src/executors/operators/function.rs
@@ -0,0 +1,29 @@
+use polars_core::error::PolarsResult;
+use polars_plan::prelude::*;
+
+use crate::operators::{DataChunk, Operator, OperatorResult, PExecutionContext};
+
+#[derive(Clone)]
+pub struct FunctionOperator {
+    pub(crate) function: FunctionNode,
+}
+
+impl Operator for FunctionOperator {
+    fn execute(
+        &mut self,
+        _context: &PExecutionContext,
+        chunk: &DataChunk,
+    ) -> PolarsResult<OperatorResult> {
+        Ok(OperatorResult::Finished(
+            chunk.with_data(self.function.evaluate(chunk.data.clone())?),
+        ))
+    }
+
+    fn split(&self, _thread_no: usize) -> Box<dyn Operator> {
+        Box::new(self.clone())
+    }
+
+    fn fmt(&self) -> &str {
+        "function"
+    }
+}
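
For orientation: execute is called once per incoming DataChunk, and a stateless function can return OperatorResult::Finished for the chunk immediately; split gives every pipeline thread its own clone of the operator. A self-contained toy model of that contract (hypothetical names, not the polars API):

// Toy stand-ins for Operator/DataChunk; chunks are plain Vec<i64>.
trait ToyOperator {
    // Called once per incoming batch.
    fn execute(&mut self, chunk: Vec<i64>) -> Vec<i64>;
    // Each worker thread gets its own copy of the operator.
    fn split(&self) -> Box<dyn ToyOperator>;
}

#[derive(Clone)]
struct Double;

impl ToyOperator for Double {
    fn execute(&mut self, chunk: Vec<i64>) -> Vec<i64> {
        // A stateless, elementwise function finishes a chunk in one call.
        chunk.into_iter().map(|v| v * 2).collect()
    }
    fn split(&self) -> Box<dyn ToyOperator> {
        Box::new(self.clone())
    }
}

fn main() {
    let mut op = Double;
    for chunk in [vec![1, 2], vec![3, 4]] {
        println!("{:?}", op.execute(chunk));
    }
}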
6 changes: 4 additions & 2 deletions polars/polars-lazy/polars-pipe/src/executors/operators/mod.rs
@@ -1,7 +1,9 @@
-mod dummy;
 mod filter;
+mod function;
+mod placeholder;
 mod projection;
 
-pub(crate) use dummy::PlaceHolder;
 pub(crate) use filter::*;
+pub(crate) use function::*;
+pub(crate) use placeholder::PlaceHolder;
 pub(crate) use projection::*;
6 changes: 6 additions & 0 deletions polars/polars-lazy/polars-pipe/src/pipeline/convert.rs
@@ -320,6 +320,12 @@ where
             };
             Box::new(op) as Box<dyn Operator>
         }
+        MapFunction { function, .. } => {
+            let op = operators::FunctionOperator {
+                function: function.clone(),
+            };
+            Box::new(op) as Box<dyn Operator>
+        }
 
         lp => {
             panic!("operator {lp:?} not (yet) supported")
@@ -325,7 +325,7 @@ pub(super) fn strptime(s: &Series, options: &StrpTimeOptions) -> PolarsResult<Series>
                 .into(),
             ));
         }
-        #[cfg(feature = "regex")]
+        #[cfg(feature = "timezones")]
         (false, Some(fmt)) => TZ_AWARE_RE.is_match(fmt),
         (false, _) => false,
     };
2 changes: 1 addition & 1 deletion polars/polars-lazy/polars-plan/src/logical_plan/aexpr.rs
@@ -118,7 +118,7 @@ impl AExpr {
         use AExpr::*;
         match self {
             Function { options, .. } | AnonymousFunction { options, .. } => {
-                options.collect_groups == ApplyOptions::ApplyGroups
+                options.is_groups_sensitive()
             }
             Sort { .. }
             | SortBy { .. }
11 changes: 11 additions & 0 deletions polars/polars-lazy/polars-plan/src/logical_plan/functions/mod.rs
@@ -70,6 +70,17 @@ impl PartialEq for FunctionNode {
 }
 
 impl FunctionNode {
+    /// Whether this function can run on batches of data at a time.
+    pub fn is_streamable(&self) -> bool {
+        use FunctionNode::*;
+        match self {
+            Opaque { .. } | Rechunk | Pipeline { .. } => false,
+            #[cfg(feature = "merge_sorted")]
+            MergeSorted { .. } => false,
+            DropNulls { .. } | FastProjection { .. } | Unnest { .. } => true,
+        }
+    }
+
     pub(crate) fn schema<'a>(
         &self,
         input_schema: &'a SchemaRef,
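
The rule encoded by is_streamable: a function may stream when applying it batch by batch and concatenating the results equals applying it to the fully materialized frame. DropNulls has that property; an Opaque UDF gives no such guarantee, so it stays conservatively non-streamable here. A toy illustration of the property in plain Rust (not the polars API):

// drop_nulls is batch-local: per-chunk results concatenate to the
// same answer as one pass over the whole column. A sort, by contrast,
// is not batch-local, so it cannot stream this way.
fn drop_nulls(batch: &[Option<i32>]) -> Vec<i32> {
    batch.iter().copied().flatten().collect()
}

fn main() {
    let chunks = [vec![Some(1), None], vec![None, Some(3)]];
    // streaming: apply per chunk, then concatenate
    let streamed: Vec<i32> = chunks.iter().flat_map(|c| drop_nulls(c)).collect();
    // in-memory: apply to the whole column at once
    let whole: Vec<Option<i32>> = chunks.concat();
    assert_eq!(streamed, drop_nulls(&whole));
}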
9 changes: 6 additions & 3 deletions polars/polars-lazy/polars-plan/src/logical_plan/options.rs
@@ -217,9 +217,12 @@ pub struct FunctionOptions {
 }
 
 impl FunctionOptions {
-    /// Whether this can simply applied elementwise
-    pub fn is_mappable(&self) -> bool {
-        !matches!(self.collect_groups, ApplyOptions::ApplyGroups)
+    /// Any function that is sensitive to the number of elements in a group
+    /// - Aggregations
+    /// - Sorts
+    /// - Counts
+    pub fn is_groups_sensitive(&self) -> bool {
+        matches!(self.collect_groups, ApplyOptions::ApplyGroups)
     }
 }
 
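
is_groups_sensitive is the renamed, inverted successor of is_mappable: call sites now ask whether a function depends on the whole group rather than whether it can be applied elementwise. A toy contrast in plain Rust (not the polars API):

// An elementwise function maps each value independently, so chunk
// boundaries and group sizes cannot change its output. A
// groups-sensitive function such as mean needs every element of the
// group, which is why the TernaryExpr change below treats such
// functions like aggregations.
fn double(xs: &[f64]) -> Vec<f64> {
    xs.iter().map(|x| x * 2.0).collect()
}

fn mean(xs: &[f64]) -> f64 {
    xs.iter().sum::<f64>() / xs.len() as f64
}

fn main() {
    let group = [1.0, 2.0, 3.0];
    assert_eq!(double(&group), vec![2.0, 4.0, 6.0]);
    assert_eq!(mean(&group), 2.0);
    // Truncating the group changes mean's answer, but not double's.
    assert_ne!(mean(&group[..1]), mean(&group));
}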
@@ -286,7 +286,7 @@ impl PhysicalExpr for TernaryExpr {
                 Expr::Agg(_) => has_agg = true,
                 Expr::Function { options, .. }
                 | Expr::AnonymousFunction { options, .. }
-                    if !options.is_mappable() =>
+                    if options.is_groups_sensitive() =>
                 {
                     has_agg = true
                 }
105 changes: 69 additions & 36 deletions polars/polars-lazy/src/physical_plan/streaming/convert.rs
@@ -20,6 +20,13 @@ use crate::physical_plan::PhysicalExpr;

 pub struct Wrap(Arc<dyn PhysicalExpr>);
 
+type IsSink = bool;
+// a rhs of a join will be replaced later
+type IsRhsJoin = bool;
+
+const IS_SINK: bool = true;
+const IS_RHS_JOIN: bool = true;
+
 impl PhysicalPipedExpr for Wrap {
     fn evaluate(&self, chunk: &DataChunk, state: &dyn Any) -> PolarsResult<Series> {
         let state = state.downcast_ref::<ExecutionState>().unwrap();
@@ -105,6 +112,30 @@ fn streamable_join(join_type: &JoinType) -> bool {
     }
 }
 
+// The index of the pipeline tree we are building at this moment
+// if we have a node we cannot do streaming, we have finished that pipeline tree
+// and start a new one.
+type CurrentIdx = usize;
+
+fn process_non_streamable_node(
+    current_idx: &mut CurrentIdx,
+    state: &mut Branch,
+    stack: &mut Vec<(Node, Branch, CurrentIdx)>,
+    scratch: &mut Vec<Node>,
+    pipeline_trees: &mut Vec<Vec<Branch>>,
+    lp: &ALogicalPlan,
+) {
+    if state.streamable {
+        *current_idx += 1;
+        pipeline_trees.push(vec![]);
+    }
+    state.streamable = false;
+    lp.copy_inputs(scratch);
+    while let Some(input) = scratch.pop() {
+        stack.push((input, Branch::default(), *current_idx))
+    }
+}
+
 pub(crate) fn insert_streaming_nodes(
     root: Node,
     lp_arena: &mut Arena<ALogicalPlan>,
@@ -120,10 +151,6 @@ pub(crate) fn insert_streaming_nodes(

     let mut stack = Vec::with_capacity(16);
 
-    // The index of the pipeline tree we are building at this moment
-    // if we have a node we cannot do streaming, we have finished that pipeline tree
-    // and start a new one.
-    type CurrentIdx = usize;
     stack.push((root, Branch::default(), 0 as CurrentIdx));
 
     // A state holds a full pipeline until the breaker
@@ -148,17 +175,17 @@
         match lp_arena.get(root) {
             Selection { input, predicate } if is_streamable(*predicate, expr_arena) => {
                 state.streamable = true;
-                state.operators_sinks.push((false, false, root));
+                state.operators_sinks.push((!IS_SINK, !IS_RHS_JOIN, root));
                 stack.push((*input, state, current_idx))
             }
             HStack { input, exprs, .. } if all_streamable(exprs, expr_arena) => {
                 state.streamable = true;
-                state.operators_sinks.push((false, false, root));
+                state.operators_sinks.push((!IS_SINK, !IS_RHS_JOIN, root));
                 stack.push((*input, state, current_idx))
             }
             Slice { input, offset, .. } if *offset >= 0 => {
                 state.streamable = true;
-                state.operators_sinks.push((true, false, root));
+                state.operators_sinks.push((IS_SINK, !IS_RHS_JOIN, root));
                 stack.push((*input, state, current_idx))
             }
             FileSink { input, .. } => {
@@ -175,30 +202,39 @@
                 && all_column(by_column, expr_arena) =>
             {
                 state.streamable = true;
-                state.operators_sinks.push((true, false, root));
-                stack.push((*input, state, current_idx))
-            }
-            MapFunction {
-                input,
-                function: FunctionNode::FastProjection { .. },
-            } => {
-                state.streamable = true;
-                state.operators_sinks.push((false, false, root));
+                state.operators_sinks.push((IS_SINK, !IS_RHS_JOIN, root));
                 stack.push((*input, state, current_idx))
             }
             Projection { input, expr, .. } if all_streamable(expr, expr_arena) => {
                 state.streamable = true;
-                state.operators_sinks.push((false, false, root));
+                state.operators_sinks.push((!IS_SINK, !IS_RHS_JOIN, root));
                 stack.push((*input, state, current_idx))
             }
+            // Rechunks are ignored
             MapFunction {
                 input,
                 function: FunctionNode::Rechunk,
             } => {
-                // we ignore a rechunk
                 state.streamable = true;
                 stack.push((*input, state, current_idx))
             }
+            // Streamable functions will be converted
+            lp @ MapFunction { input, function } => {
+                if function.is_streamable() {
+                    state.streamable = true;
+                    state.operators_sinks.push((!IS_SINK, !IS_RHS_JOIN, root));
+                    stack.push((*input, state, current_idx))
+                } else {
+                    process_non_streamable_node(
+                        &mut current_idx,
+                        &mut state,
+                        &mut stack,
+                        scratch,
+                        &mut pipeline_trees,
+                        lp,
+                    )
+                }
+            }
             #[cfg(feature = "csv-file")]
             CsvScan { .. } => {
                 if state.streamable {
Expand Down Expand Up @@ -245,7 +281,9 @@ pub(crate) fn insert_streaming_nodes(
                 // rhs is second, so that is first on the stack
                 let mut state_right = state;
                 state_right.join_count = 0;
-                state_right.operators_sinks.push((true, true, root));
+                state_right
+                    .operators_sinks
+                    .push((IS_SINK, IS_RHS_JOIN, root));
                 stack.push((input_right, state_right, current_idx));
 
                 // we want to traverse lhs first, so push it latest on the stack
@@ -255,7 +293,9 @@
                     join_count,
                     ..Default::default()
                 };
-                state_left.operators_sinks.push((true, false, root));
+                state_left
+                    .operators_sinks
+                    .push((IS_SINK, !IS_RHS_JOIN, root));
                 stack.push((input_left, state_left, current_idx));
             }
             // add globbing patterns
@@ -309,23 +349,20 @@
                     .all(|dt| allowed_dtype(dt, string_cache))
                 {
                     state.streamable = true;
-                    state.operators_sinks.push((true, false, root));
+                    state.operators_sinks.push((IS_SINK, !IS_RHS_JOIN, root));
                     stack.push((*input, state, current_idx))
                 } else {
                     stack.push((*input, Branch::default(), current_idx))
                 }
             }
-            lp => {
-                if state.streamable {
-                    current_idx += 1;
-                    pipeline_trees.push(vec![]);
-                }
-                state.streamable = false;
-                lp.copy_inputs(scratch);
-                while let Some(input) = scratch.pop() {
-                    stack.push((input, Branch::default(), current_idx))
-                }
-            }
+            lp => process_non_streamable_node(
+                &mut current_idx,
+                &mut state,
+                &mut stack,
+                scratch,
+                &mut pipeline_trees,
+                lp,
+            ),
         }
     }
     let mut inserted = false;
@@ -448,10 +485,6 @@
     Ok(inserted)
 }
 
-type IsSink = bool;
-// a rhs of a join will be replaced later
-type IsRhsJoin = bool;
-
 #[derive(Default, Debug, Clone)]
 struct Branch {
     streamable: bool,
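
Taken together: insert_streaming_nodes walks the plan from the root and grows the current pipeline tree while nodes are streamable; process_non_streamable_node (factored out above) closes that tree when a blocking node interrupts a streamable run and restarts traversal at the node's inputs. A self-contained sketch of just that bookkeeping on a toy plan (hypothetical types; the real code also tracks operators, sinks and join counts):

struct Plan {
    name: &'static str,
    streamable: bool,
    inputs: Vec<Plan>,
}

// Depth-first walk: consecutive streamable nodes share a pipeline
// tree; a non-streamable node after a streamable run finishes that
// tree, and its inputs start fresh branches.
fn collect_pipeline_trees(root: &Plan) -> Vec<Vec<&'static str>> {
    let mut trees: Vec<Vec<&'static str>> = vec![vec![]];
    let mut current_idx = 0;
    // (node, was the branch streamable so far, its pipeline-tree index)
    let mut stack = vec![(root, false, 0usize)];
    while let Some((node, branch_streamable, idx)) = stack.pop() {
        if node.streamable {
            trees[idx].push(node.name);
            for input in &node.inputs {
                stack.push((input, true, idx));
            }
        } else {
            if branch_streamable {
                // a blocking node ends the streamable run
                current_idx += 1;
                trees.push(vec![]);
            }
            for input in &node.inputs {
                stack.push((input, false, current_idx));
            }
        }
    }
    trees
}

fn main() {
    // filter_1 -> groupby (blocking) -> filter_2 -> scan
    let plan = Plan {
        name: "filter_1",
        streamable: true,
        inputs: vec![Plan {
            name: "groupby",
            streamable: false,
            inputs: vec![Plan {
                name: "filter_2",
                streamable: true,
                inputs: vec![Plan {
                    name: "scan",
                    streamable: true,
                    inputs: vec![],
                }],
            }],
        }],
    };
    // prints [["filter_1"], ["filter_2", "scan"]]: two pipeline trees
    println!("{:?}", collect_pipeline_trees(&plan));
}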
