Rework the python bindings using conversion traits from arrow-rs (#873)
* Reuse wait_for_future from dataframe.rs

* Improve module organization

* Reorganize tests

* Support Binary and LargeBinary arrays for ScalarValue::try_from_array

* Define column and literal functions in python

* Apply isort

* Add test for importing from datafusion.functions

* Rename internals to _internal

* Make classes inheritable; add tests for imports; set module

* Remove PyVolatility

* Move ScalarUdf to udf.rs

* Factor out PyScalarUDF and PyAggregateUDF

* Refactor UDF and UDAF construction

* Set public as the default database for the catalog
kszucs authored Nov 2, 2021
1 parent 91b5469 commit 1c351ec
Showing 45 changed files with 1,498 additions and 1,506 deletions.
21 changes: 11 additions & 10 deletions ballista/rust/core/src/serde/logical_plan/from_proto.rs
@@ -106,14 +106,15 @@ impl TryInto<LogicalPlan> for &protobuf::LogicalPlanNode {
             }
             LogicalPlanType::Selection(selection) => {
                 let input: LogicalPlan = convert_box_required!(selection.input)?;
+                let expr: Expr = selection
+                    .expr
+                    .as_ref()
+                    .ok_or_else(|| {
+                        BallistaError::General("expression required".to_string())
+                    })?
+                    .try_into()?;
                 LogicalPlanBuilder::from(input)
-                    .filter(
-                        selection
-                            .expr
-                            .as_ref()
-                            .expect("expression required")
-                            .try_into()?,
-                    )?
+                    .filter(expr)?
                     .build()
                     .map_err(|e| e.into())
             }
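This hunk swaps a panicking `expect` for `ok_or_else`, so a plan with a missing filter expression now surfaces as a `BallistaError` the caller can handle instead of aborting the process. A minimal sketch of the same Option-to-Result pattern, using a stand-in error type rather than the real `BallistaError`:

#[derive(Debug)]
enum Error {
    General(String),
}

fn require_expr(expr: Option<&str>) -> Result<&str, Error> {
    // `ok_or_else` builds the error lazily, only on the failure path,
    // and turns the Option into a Result that `?` can propagate.
    expr.ok_or_else(|| Error::General("expression required".to_string()))
}

fn main() {
    assert!(require_expr(Some("a > 1")).is_ok());
    assert!(matches!(require_expr(None), Err(Error::General(_))));
}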
@@ -123,7 +124,7 @@ impl TryInto<LogicalPlan> for &protobuf::LogicalPlanNode {
                     .window_expr
                     .iter()
                     .map(|expr| expr.try_into())
-                    .collect::<Result<Vec<_>, _>>()?;
+                    .collect::<Result<Vec<Expr>, _>>()?;
                 LogicalPlanBuilder::from(input)
                     .window(window_expr)?
                     .build()
@@ -135,12 +136,12 @@ impl TryInto<LogicalPlan> for &protobuf::LogicalPlanNode {
                     .group_expr
                     .iter()
                     .map(|expr| expr.try_into())
-                    .collect::<Result<Vec<_>, _>>()?;
+                    .collect::<Result<Vec<Expr>, _>>()?;
                 let aggr_expr = aggregate
                     .aggr_expr
                     .iter()
                     .map(|expr| expr.try_into())
-                    .collect::<Result<Vec<_>, _>>()?;
+                    .collect::<Result<Vec<Expr>, _>>()?;
                 LogicalPlanBuilder::from(input)
                     .aggregate(group_expr, aggr_expr)?
                     .build()
2 changes: 1 addition & 1 deletion datafusion-cli/Cargo.toml
@@ -31,5 +31,5 @@ clap = "2.33"
 rustyline = "9.0"
 tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] }
 datafusion = { path = "../datafusion", version = "5.1.0" }
-arrow = { version = "6.0.0" }
+arrow = { version = "6.0.0" }
 ballista = { path = "../ballista/rust/client", version = "0.6.0" }
2 changes: 2 additions & 0 deletions datafusion/Cargo.toml
@@ -43,6 +43,7 @@ simd = ["arrow/simd"]
 crypto_expressions = ["md-5", "sha2", "blake2", "blake3"]
 regex_expressions = ["regex"]
 unicode_expressions = ["unicode-segmentation"]
+pyarrow = ["pyo3", "arrow/pyarrow"]
 # Used for testing ONLY: causes all values to hash to the same value (test for collisions)
 force_hash_collisions = []
 # Used to enable the avro format
@@ -75,6 +76,7 @@ smallvec = { version = "1.6", features = ["union"] }
 rand = "0.8"
 avro-rs = { version = "0.13", features = ["snappy"], optional = true }
 num-traits = { version = "0.2", optional = true }
+pyo3 = { version = "0.14", optional = true }
 
 [dev-dependencies]
 criterion = "0.3"
8 changes: 8 additions & 0 deletions datafusion/src/execution/context.rs
@@ -329,6 +329,14 @@ impl ExecutionContext {
         )))
     }
 
+    /// Creates an empty DataFrame.
+    pub fn read_empty(&self) -> Result<Arc<dyn DataFrame>> {
+        Ok(Arc::new(DataFrameImpl::new(
+            self.state.clone(),
+            &LogicalPlanBuilder::empty(true).build()?,
+        )))
+    }
+
     /// Creates a DataFrame for reading a CSV data source.
     pub async fn read_csv(
         &mut self,
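A hedged usage sketch of the new `read_empty` API (assumes a Tokio runtime and the rest of the 5.x `DataFrame` surface): the underlying `EmptyRelation` produces a single row with no columns, which makes it a convenient source for constant expressions, akin to `SELECT 1` without a table.

use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = ExecutionContext::new();
    // One row, zero columns.
    let df = ctx.read_empty()?;
    // Project a literal onto it: the logical equivalent of `SELECT 1 AS one`.
    let batches = df.select(vec![lit(1).alias("one")])?.collect().await?;
    println!("{:?}", batches);
    Ok(())
}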
4 changes: 4 additions & 0 deletions datafusion/src/lib.rs
@@ -232,6 +232,10 @@ pub use arrow;
 pub use parquet;
 
 pub(crate) mod field_util;
+
+#[cfg(feature = "pyarrow")]
+mod pyarrow;
+
 #[cfg(test)]
 pub mod test;
 pub mod test_util;
30 changes: 19 additions & 11 deletions datafusion/src/logical_plan/builder.rs
@@ -33,6 +33,7 @@ use arrow::{
     record_batch::RecordBatch,
 };
 use std::convert::TryFrom;
+use std::iter;
 use std::{
     collections::{HashMap, HashSet},
     sync::Arc,
@@ -426,14 +427,17 @@
         Ok(plan)
     }
     /// Apply a projection without alias.
-    pub fn project(&self, expr: impl IntoIterator<Item = Expr>) -> Result<Self> {
+    pub fn project(
+        &self,
+        expr: impl IntoIterator<Item = impl Into<Expr>>,
+    ) -> Result<Self> {
         self.project_with_alias(expr, None)
     }
 
     /// Apply a projection with alias
     pub fn project_with_alias(
         &self,
-        expr: impl IntoIterator<Item = Expr>,
+        expr: impl IntoIterator<Item = impl Into<Expr>>,
         alias: Option<String>,
     ) -> Result<Self> {
         Ok(Self::from(project_with_alias(
@@ -444,8 +448,8 @@
     }
 
     /// Apply a filter
-    pub fn filter(&self, expr: Expr) -> Result<Self> {
-        let expr = normalize_col(expr, &self.plan)?;
+    pub fn filter(&self, expr: impl Into<Expr>) -> Result<Self> {
+        let expr = normalize_col(expr.into(), &self.plan)?;
         Ok(Self::from(LogicalPlan::Filter {
             predicate: expr,
             input: Arc::new(self.plan.clone()),
@@ -461,7 +465,7 @@
     }
 
     /// Apply a sort
-    pub fn sort(&self, exprs: impl IntoIterator<Item = Expr>) -> Result<Self> {
+    pub fn sort(&self, exprs: impl IntoIterator<Item = impl Into<Expr>>) -> Result<Self> {
         Ok(Self::from(LogicalPlan::Sort {
             expr: normalize_cols(exprs, &self.plan)?,
             input: Arc::new(self.plan.clone()),
@@ -477,7 +481,7 @@
     pub fn distinct(&self) -> Result<Self> {
         let projection_expr = expand_wildcard(self.plan.schema(), &self.plan)?;
         let plan = LogicalPlanBuilder::from(self.plan.clone())
-            .aggregate(projection_expr, vec![])?
+            .aggregate(projection_expr, iter::empty::<Expr>())?
             .build()?;
         Self::from(plan).project(vec![Expr::Wildcard])
     }
@@ -629,8 +633,11 @@
     }
 
     /// Apply a window functions to extend the schema
-    pub fn window(&self, window_expr: impl IntoIterator<Item = Expr>) -> Result<Self> {
-        let window_expr = window_expr.into_iter().collect::<Vec<Expr>>();
+    pub fn window(
+        &self,
+        window_expr: impl IntoIterator<Item = impl Into<Expr>>,
+    ) -> Result<Self> {
+        let window_expr = normalize_cols(window_expr, &self.plan)?;
         let all_expr = window_expr.iter();
         validate_unique_names("Windows", all_expr.clone(), self.plan.schema())?;
         let mut window_fields: Vec<DFField> =
@@ -648,8 +655,8 @@
     /// value of the `group_expr`;
     pub fn aggregate(
         &self,
-        group_expr: impl IntoIterator<Item = Expr>,
-        aggr_expr: impl IntoIterator<Item = Expr>,
+        group_expr: impl IntoIterator<Item = impl Into<Expr>>,
+        aggr_expr: impl IntoIterator<Item = impl Into<Expr>>,
     ) -> Result<Self> {
         let group_expr = normalize_cols(group_expr, &self.plan)?;
         let aggr_expr = normalize_cols(aggr_expr, &self.plan)?;
@@ -796,12 +803,13 @@ pub fn union_with_alias(
 /// * An invalid expression is used (e.g. a `sort` expression)
 pub fn project_with_alias(
     plan: LogicalPlan,
-    expr: impl IntoIterator<Item = Expr>,
+    expr: impl IntoIterator<Item = impl Into<Expr>>,
     alias: Option<String>,
 ) -> Result<LogicalPlan> {
     let input_schema = plan.schema();
     let mut projected_expr = vec![];
     for e in expr {
+        let e = e.into();
         match e {
             Expr::Wildcard => {
                 projected_expr.extend(expand_wildcard(input_schema, &plan)?)
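The common thread in this file: every builder entry point now accepts `impl Into<Expr>` instead of `Expr`, which is what lets the Python bindings hand their wrapper expression type straight to the builder. A minimal sketch of the mechanism with a hypothetical newtype (`MyExpr` is illustrative, not part of the codebase):

use datafusion::error::Result;
use datafusion::logical_plan::{lit, Expr, LogicalPlanBuilder};

// Stand-in for the PyExpr wrapper used by the Python bindings.
struct MyExpr(Expr);

impl From<MyExpr> for Expr {
    fn from(e: MyExpr) -> Expr {
        e.0
    }
}

fn main() -> Result<()> {
    // `project` takes `impl IntoIterator<Item = impl Into<Expr>>`, so a
    // Vec<MyExpr> works without unwrapping each element first.
    let plan = LogicalPlanBuilder::empty(true)
        .project(vec![MyExpr(lit(1))])?
        .build()?;
    println!("{:?}", plan);
    Ok(())
}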
12 changes: 10 additions & 2 deletions datafusion/src/logical_plan/expr.rs
@@ -540,6 +540,9 @@ impl Expr {
     /// This function errors when it is impossible to cast the
     /// expression to the target [arrow::datatypes::DataType].
     pub fn cast_to(self, cast_to_type: &DataType, schema: &DFSchema) -> Result<Expr> {
+        // TODO(kszucs): most of the operations do not validate the type correctness
+        // like all of the binary expressions below. Perhaps Expr should track the
+        // type of the expression?
         let this_type = self.get_type(schema)?;
         if this_type == *cast_to_type {
             Ok(self)
@@ -1305,10 +1308,13 @@ fn normalize_col_with_schemas(
 /// Recursively normalize all Column expressions in a list of expression trees
 #[inline]
 pub fn normalize_cols(
-    exprs: impl IntoIterator<Item = Expr>,
+    exprs: impl IntoIterator<Item = impl Into<Expr>>,
     plan: &LogicalPlan,
 ) -> Result<Vec<Expr>> {
-    exprs.into_iter().map(|e| normalize_col(e, plan)).collect()
+    exprs
+        .into_iter()
+        .map(|e| normalize_col(e.into(), plan))
+        .collect()
 }
 
 /// Recursively 'unnormalize' (remove all qualifiers) from an
@@ -1544,6 +1550,8 @@ pub fn approx_distinct(expr: Expr) -> Expr {
     }
 }
 
+// TODO(kszucs): this seems buggy, unary_scalar_expr! is used for many
+// varying arity functions
 /// Create an convenience function representing a unary scalar function
 macro_rules! unary_scalar_expr {
     ($ENUM:ident, $FUNC:ident) => {
5 changes: 3 additions & 2 deletions datafusion/src/optimizer/common_subexpr_eliminate.rs
@@ -631,6 +631,7 @@ mod test {
         avg, binary_expr, col, lit, sum, LogicalPlanBuilder, Operator,
     };
     use crate::test::*;
+    use std::iter;
 
     fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) {
         let optimizer = CommonSubexprEliminate {};
@@ -688,7 +689,7 @@
 
         let plan = LogicalPlanBuilder::from(table_scan)
             .aggregate(
-                vec![],
+                iter::empty::<Expr>(),
                 vec![
                     sum(binary_expr(
                         col("a"),
@@ -723,7 +724,7 @@
 
         let plan = LogicalPlanBuilder::from(table_scan)
             .aggregate(
-                vec![],
+                iter::empty::<Expr>(),
                 vec![
                     binary_expr(lit(1), Operator::Plus, avg(col("a"))),
                     binary_expr(lit(1), Operator::Minus, avg(col("a"))),
4 changes: 2 additions & 2 deletions datafusion/src/optimizer/projection_push_down.rs
@@ -475,7 +475,7 @@ mod tests {
         let table_scan = test_table_scan()?;
 
         let plan = LogicalPlanBuilder::from(table_scan)
-            .aggregate(vec![], vec![max(col("b"))])?
+            .aggregate(Vec::<Expr>::new(), vec![max(col("b"))])?
             .build()?;
 
         let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(#test.b)]]\
@@ -508,7 +508,7 @@
 
         let plan = LogicalPlanBuilder::from(table_scan)
             .filter(col("c"))?
-            .aggregate(vec![], vec![max(col("b"))])?
+            .aggregate(Vec::<Expr>::new(), vec![max(col("b"))])?
             .build()?;
 
         let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(#test.b)]]\
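The test churn in these two optimizer files has a single cause: with the parameters generic over `impl Into<Expr>`, a bare `vec![]` gives the compiler nothing to infer the element type from, so the empty group-by list must spell it out. A sketch showing that the two spellings used in this commit are interchangeable:

use std::iter;

use datafusion::error::Result;
use datafusion::logical_plan::{lit, Expr, LogicalPlanBuilder};

fn main() -> Result<()> {
    let builder = LogicalPlanBuilder::empty(true);
    // `vec![]` alone is now ambiguous; both forms pin the item type to Expr.
    let a = builder.aggregate(iter::empty::<Expr>(), vec![lit(1)])?.build()?;
    let b = builder.aggregate(Vec::<Expr>::new(), vec![lit(1)])?.build()?;
    assert_eq!(format!("{:?}", a), format!("{:?}", b));
    Ok(())
}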
67 changes: 67 additions & 0 deletions datafusion/src/pyarrow.rs
@@ -0,0 +1,67 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use pyo3::exceptions::{PyException, PyNotImplementedError};
use pyo3::prelude::*;
use pyo3::types::PyList;
use pyo3::PyNativeType;

use crate::arrow::array::ArrayData;
use crate::arrow::pyarrow::PyArrowConvert;
use crate::error::DataFusionError;
use crate::scalar::ScalarValue;

impl From<DataFusionError> for PyErr {
    fn from(err: DataFusionError) -> PyErr {
        PyException::new_err(err.to_string())
    }
}

impl PyArrowConvert for ScalarValue {
    fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
        let py = value.py();
        let typ = value.getattr("type")?;
        let val = value.call_method0("as_py")?;

        // construct pyarrow array from the python value and pyarrow type
        let factory = py.import("pyarrow")?.getattr("array")?;
        let args = PyList::new(py, &[val]);
        let array = factory.call1((args, typ))?;

        // convert the pyarrow array to rust array using C data interface
        let array = array.extract::<ArrayData>()?;
        let scalar = ScalarValue::try_from_array(&array.into(), 0)?;

        Ok(scalar)
    }

    fn to_pyarrow(&self, _py: Python) -> PyResult<PyObject> {
        Err(PyNotImplementedError::new_err("Not implemented"))
    }
}

impl<'source> FromPyObject<'source> for ScalarValue {
    fn extract(value: &'source PyAny) -> PyResult<Self> {
        Self::from_pyarrow(value)
    }
}

impl<'a> IntoPy<PyObject> for ScalarValue {
    fn into_py(self, py: Python) -> PyObject {
        self.to_pyarrow(py).unwrap()
    }
}
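With `FromPyObject` in place, any pyo3 boundary can extract a `ScalarValue` directly from a `pyarrow` scalar object; the impl round-trips through a one-element pyarrow array and the C data interface. A hedged sketch of driving it from the Rust side (assumes the `pyarrow` feature, pyo3's `auto-initialize`, and a Python environment with pyarrow installed):

use datafusion::scalar::ScalarValue;
use pyo3::prelude::*;

fn main() -> PyResult<()> {
    Python::with_gil(|py| {
        // Build a pyarrow scalar on the Python side...
        let pa = py.import("pyarrow")?;
        let scalar = pa.getattr("scalar")?.call1((42,))?;
        // ...and pull it into Rust through the FromPyObject impl above.
        let value: ScalarValue = scalar.extract()?;
        println!("{:?}", value); // e.g. Int64(42)
        Ok(())
    })
}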
4 changes: 4 additions & 0 deletions datafusion/src/scalar.rs
@@ -1093,6 +1093,10 @@ impl ScalarValue {
             DataType::Int32 => typed_cast!(array, index, Int32Array, Int32),
             DataType::Int16 => typed_cast!(array, index, Int16Array, Int16),
             DataType::Int8 => typed_cast!(array, index, Int8Array, Int8),
+            DataType::Binary => typed_cast!(array, index, BinaryArray, Binary),
+            DataType::LargeBinary => {
+                typed_cast!(array, index, LargeBinaryArray, LargeBinary)
+            }
             DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8),
             DataType::LargeUtf8 => typed_cast!(array, index, LargeStringArray, LargeUtf8),
             DataType::List(nested_type) => {
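These new match arms are what the commit message calls Binary and LargeBinary support for `ScalarValue::try_from_array`; the Python bindings hit them whenever pyarrow hands over `binary` or `large_binary` values. A small sketch of the resulting behavior, nulls included:

use std::sync::Arc;

use datafusion::arrow::array::{ArrayRef, BinaryArray};
use datafusion::error::Result;
use datafusion::scalar::ScalarValue;

fn main() -> Result<()> {
    let array: ArrayRef =
        Arc::new(BinaryArray::from_opt_vec(vec![Some(&b"foo"[..]), None]));
    // Before this change, binary types fell through to the error arm.
    assert_eq!(
        ScalarValue::try_from_array(&array, 0)?,
        ScalarValue::Binary(Some(b"foo".to_vec()))
    );
    assert_eq!(
        ScalarValue::try_from_array(&array, 1)?,
        ScalarValue::Binary(None)
    );
    Ok(())
}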
3 changes: 2 additions & 1 deletion datafusion/src/sql/planner.rs
@@ -18,6 +18,7 @@
 //! SQL Query Planner (produces logical plan from SQL AST)
 use std::collections::HashSet;
+use std::iter;
 use std::str::FromStr;
 use std::sync::Arc;
 use std::{convert::TryInto, vec};
@@ -822,7 +823,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
 
         let plan = if select.distinct {
             return LogicalPlanBuilder::from(plan)
-                .aggregate(select_exprs_post_aggr, vec![])?
+                .aggregate(select_exprs_post_aggr, iter::empty::<Expr>())?
                 .build();
         } else {
             plan
[Diff truncated: the remaining 33 of the 45 changed files are not shown.]
