Skip to content

Commit

Permalink
g
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs committed Aug 13, 2021
1 parent 3810653 commit fbea0ed
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 43 deletions.
13 changes: 7 additions & 6 deletions datafusion/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,10 +214,11 @@ impl LogicalPlanBuilder {
/// This function errors under any of the following conditions:
/// * Two or more expressions have the same name
/// * An invalid expression is used (e.g. a `sort` expression)
pub fn project(&self, expr: impl IntoIterator<Item = Expr>) -> Result<Self> {
pub fn project(&self, expr: impl IntoIterator<Item = impl Into<Expr>>) -> Result<Self> {
let input_schema = self.plan.schema();
let mut projected_expr = vec![];
for e in expr {
let e = e.into();
match e {
Expr::Wildcard => {
projected_expr.extend(expand_wildcard(input_schema, &self.plan)?)
Expand All @@ -239,8 +240,8 @@ impl LogicalPlanBuilder {
}

/// Apply a filter
pub fn filter(&self, expr: Expr) -> Result<Self> {
let expr = normalize_col(expr, &self.plan)?;
pub fn filter(&self, expr: impl Into<Expr>) -> Result<Self> {
let expr = normalize_col(expr.into(), &self.plan)?;
Ok(Self::from(LogicalPlan::Filter {
predicate: expr,
input: Arc::new(self.plan.clone()),
Expand All @@ -256,7 +257,7 @@ impl LogicalPlanBuilder {
}

/// Apply a sort
pub fn sort(&self, exprs: impl IntoIterator<Item = Expr>) -> Result<Self> {
pub fn sort(&self, exprs: impl IntoIterator<Item = impl Into<Expr>>) -> Result<Self> {
Ok(Self::from(LogicalPlan::Sort {
expr: normalize_cols(exprs, &self.plan)?,
input: Arc::new(self.plan.clone()),
Expand Down Expand Up @@ -434,8 +435,8 @@ impl LogicalPlanBuilder {
/// value of the `group_expr`;
pub fn aggregate(
&self,
group_expr: impl IntoIterator<Item = Expr>,
aggr_expr: impl IntoIterator<Item = Expr>,
group_expr: impl IntoIterator<Item = impl Into<Expr>>,
aggr_expr: impl IntoIterator<Item = impl Into<Expr>>,
) -> Result<Self> {
let group_expr = normalize_cols(group_expr, &self.plan)?;
let aggr_expr = normalize_cols(aggr_expr, &self.plan)?;
Expand Down
4 changes: 2 additions & 2 deletions datafusion/src/logical_plan/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1227,10 +1227,10 @@ fn normalize_col_with_schemas(
/// Recursively normalize all Column expressions in a list of expression trees
#[inline]
pub fn normalize_cols(
exprs: impl IntoIterator<Item = Expr>,
exprs: impl IntoIterator<Item = impl Into<Expr>>,
plan: &LogicalPlan,
) -> Result<Vec<Expr>> {
exprs.into_iter().map(|e| normalize_col(e, plan)).collect()
exprs.into_iter().map(|e| normalize_col(e.into(), plan)).collect()
}

/// Recursively 'unnormalize' (remove all qualifiers) from an
Expand Down
12 changes: 0 additions & 12 deletions datafusion/src/pyarrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,10 @@
// specific language governing permissions and limitations
// under the License.

use std::convert::TryFrom;
use std::sync::Arc;

use libc::uintptr_t;
use pyo3::exceptions::{PyException, PyNotImplementedError};
use pyo3::prelude::*;
use pyo3::types::PyList;

use crate::arrow::array::{make_array_from_raw, ArrayRef};
use crate::arrow::datatypes::{DataType, Field, Schema};
use crate::arrow::ffi;
use crate::arrow::ffi::FFI_ArrowSchema;
use crate::arrow::record_batch::RecordBatch;
//use crate::arrow::pyarrow::PyArrowConvert;
use crate::scalar::ScalarValue;

use crate::arrow::pyarrow::PyArrowConvert;
use crate::error::DataFusionError;

Expand Down
3 changes: 2 additions & 1 deletion datafusion/src/sql/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use std::collections::HashSet;
use std::str::FromStr;
use std::sync::Arc;
use std::{convert::TryInto, vec};
use std::iter;

use crate::catalog::TableReference;
use crate::datasource::TableProvider;
Expand Down Expand Up @@ -766,7 +767,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {

let plan = if select.distinct {
return LogicalPlanBuilder::from(plan)
.aggregate(select_exprs_post_aggr, vec![])?
.aggregate(select_exprs_post_aggr, iter::empty::<Expr>())?
.build();
} else {
plan
Expand Down
33 changes: 11 additions & 22 deletions python/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
// specific language governing permissions and limitations
// under the License.

use std::convert::From;
use std::sync::{Arc, Mutex};

use pyo3::{prelude::*, types::PyTuple};
use pyo3::prelude::*;
use tokio::runtime::Runtime;

use datafusion::arrow::record_batch::RecordBatch;
Expand Down Expand Up @@ -47,12 +48,9 @@ impl PyDataFrame {
impl PyDataFrame {
/// Select `expressions` from the existing PyDataFrame.
#[args(args = "*")]
fn select(&self, args: &PyTuple) -> PyResult<Self> {
let expressions = expression::from_tuple(args)?;
fn select(&self, args: Vec<PyExpr>) -> PyResult<Self> {
let builder = LogicalPlanBuilder::from(self.plan.clone());
let builder =
errors::wrap(builder.project(expressions.into_iter().map(|e| e.expr)))?;
let plan = errors::wrap(builder.build())?;
let plan = builder.project(args)?.build()?;

Ok(PyDataFrame {
ctx_state: self.ctx_state.clone(),
Expand All @@ -63,8 +61,7 @@ impl PyDataFrame {
/// Filter according to the `predicate` expression
fn filter(&self, predicate: PyExpr) -> PyResult<Self> {
let builder = LogicalPlanBuilder::from(self.plan.clone());
let builder = errors::wrap(builder.filter(predicate.expr))?;
let plan = errors::wrap(builder.build())?;
let plan = builder.filter(predicate)?.build()?;

Ok(PyDataFrame {
ctx_state: self.ctx_state.clone(),
Expand All @@ -75,11 +72,7 @@ impl PyDataFrame {
/// Aggregates using expressions
fn aggregate(&self, group_by: Vec<PyExpr>, aggs: Vec<PyExpr>) -> PyResult<Self> {
let builder = LogicalPlanBuilder::from(self.plan.clone());
let builder = errors::wrap(builder.aggregate(
group_by.into_iter().map(|e| e.expr),
aggs.into_iter().map(|e| e.expr),
))?;
let plan = errors::wrap(builder.build())?;
let plan = builder.aggregate(group_by, aggs)?.build()?;

Ok(PyDataFrame {
ctx_state: self.ctx_state.clone(),
Expand All @@ -89,10 +82,8 @@ impl PyDataFrame {

/// Sort by specified sorting expressions
fn sort(&self, exprs: Vec<PyExpr>) -> PyResult<Self> {
let exprs = exprs.into_iter().map(|e| e.expr);
let builder = LogicalPlanBuilder::from(self.plan.clone());
let builder = errors::wrap(builder.sort(exprs))?;
let plan = errors::wrap(builder.build())?;
let plan = builder.sort(exprs)?.build()?;
Ok(PyDataFrame {
ctx_state: self.ctx_state.clone(),
plan,
Expand All @@ -102,8 +93,7 @@ impl PyDataFrame {
/// Limits the plan to return at most `count` rows
fn limit(&self, count: usize) -> PyResult<Self> {
let builder = LogicalPlanBuilder::from(self.plan.clone());
let builder = errors::wrap(builder.limit(count))?;
let plan = errors::wrap(builder.build())?;
let plan = builder.limit(count)?.build()?;

Ok(PyDataFrame {
ctx_state: self.ctx_state.clone(),
Expand Down Expand Up @@ -148,10 +138,9 @@ impl PyDataFrame {
}
};

let builder =
errors::wrap(builder.join_using(&right.plan, join_type, on.clone()))?;

let plan = errors::wrap(builder.build())?;
let plan = builder
.join_using(&right.plan, join_type, on.clone())?
.build()?;

Ok(PyDataFrame {
ctx_state: self.ctx_state.clone(),
Expand Down
8 changes: 8 additions & 0 deletions python/src/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
use pyo3::{
basic::CompareOp, prelude::*, types::PyTuple, PyNumberProtocol, PyObjectProtocol,
};
use std::convert::From;
use std::vec::Vec;

use datafusion::logical_plan::Expr;
use datafusion::physical_plan::{udaf::AggregateUDF, udf::ScalarUDF};
Expand All @@ -29,6 +31,12 @@ pub(crate) struct PyExpr {
pub(crate) expr: Expr,
}

/// Unwraps the Python-facing [`PyExpr`] wrapper, yielding the [`Expr`] stored
/// in its `expr` field.
///
/// Through the standard library's blanket `Into` implementation this also
/// makes `PyExpr: Into<Expr>`, so a `PyExpr` can be handed to any API that
/// accepts `impl Into<Expr>`.
impl From<PyExpr> for Expr {
    fn from(wrapper: PyExpr) -> Self {
        // Consumes the wrapper and moves the inner expression out; no clone.
        wrapper.expr
    }
}

/// Converts a Python tuple of expressions into a `Vec<PyExpr>`.
pub(crate) fn from_tuple(value: &PyTuple) -> PyResult<Vec<PyExpr>> {
value
Expand Down

0 comments on commit fbea0ed

Please sign in to comment.