diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index 259fcb3482a7..66bec782e609 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -106,14 +106,15 @@ impl TryInto<LogicalPlan> for &protobuf::LogicalPlanNode { } LogicalPlanType::Selection(selection) => { let input: LogicalPlan = convert_box_required!(selection.input)?; + let expr: Expr = selection + .expr + .as_ref() + .ok_or_else(|| { + BallistaError::General("expression required".to_string()) + })? + .try_into()?; LogicalPlanBuilder::from(input) - .filter( - selection - .expr - .as_ref() - .expect("expression required") - .try_into()?, - )? + .filter(expr)? .build() .map_err(|e| e.into()) } @@ -123,7 +124,7 @@ impl TryInto<LogicalPlan> for &protobuf::LogicalPlanNode { .window_expr .iter() .map(|expr| expr.try_into()) - .collect::<Result<Vec<_>, _>>()?; + .collect::<Result<Vec<Expr>, _>>()?; LogicalPlanBuilder::from(input) .window(window_expr)? .build() @@ -135,12 +136,12 @@ impl TryInto<LogicalPlan> for &protobuf::LogicalPlanNode { .group_expr .iter() .map(|expr| expr.try_into()) - .collect::<Result<Vec<_>, _>>()?; + .collect::<Result<Vec<Expr>, _>>()?; let aggr_expr = aggregate .aggr_expr .iter() .map(|expr| expr.try_into()) - .collect::<Result<Vec<_>, _>>()?; + .collect::<Result<Vec<Expr>, _>>()?; LogicalPlanBuilder::from(input) .aggregate(group_expr, aggr_expr)? .build() diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index b424f498ac5f..360e873c6ed7 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -31,5 +31,5 @@ clap = "2.33" rustyline = "9.0" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] } datafusion = { path = "../datafusion", version = "5.1.0" } -arrow = { version = "6.0.0" } +arrow = { version = "6.0.0" } ballista = { path = "../ballista/rust/client", version = "0.6.0" } diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml index 793262a031bd..8aac711facc7 100644 --- a/datafusion/Cargo.toml +++ b/datafusion/Cargo.toml @@ -43,6 +43,7 @@ simd = ["arrow/simd"] crypto_expressions = ["md-5", "sha2", "blake2", "blake3"] regex_expressions = ["regex"] unicode_expressions = ["unicode-segmentation"] +pyarrow = ["pyo3", "arrow/pyarrow"] # Used for testing ONLY: causes all values to hash to the same value (test for collisions) force_hash_collisions = [] # Used to enable the avro format @@ -75,6 +76,7 @@ smallvec = { version = "1.6", features = ["union"] } rand = "0.8" avro-rs = { version = "0.13", features = ["snappy"], optional = true } num-traits = { version = "0.2", optional = true } +pyo3 = { version = "0.14", optional = true } [dev-dependencies] criterion = "0.3" diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 9be5038f47c9..81fd78518f0f 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -329,6 +329,14 @@ impl ExecutionContext { ))) } + /// Creates an empty DataFrame. + pub fn read_empty(&self) -> Result<Arc<dyn DataFrame>> { + Ok(Arc::new(DataFrameImpl::new( + self.state.clone(), + &LogicalPlanBuilder::empty(true).build()?, + ))) + } + /// Creates a DataFrame for reading a CSV data source.
pub async fn read_csv( &mut self, diff --git a/datafusion/src/lib.rs b/datafusion/src/lib.rs index 2159864d10fd..4f4cd664fd41 100644 --- a/datafusion/src/lib.rs +++ b/datafusion/src/lib.rs @@ -232,6 +232,10 @@ pub use arrow; pub use parquet; pub(crate) mod field_util; + +#[cfg(feature = "pyarrow")] +mod pyarrow; + #[cfg(test)] pub mod test; pub mod test_util; diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs index 693bf78fbe0e..a9d814f66eb0 100644 --- a/datafusion/src/logical_plan/builder.rs +++ b/datafusion/src/logical_plan/builder.rs @@ -33,6 +33,7 @@ use arrow::{ record_batch::RecordBatch, }; use std::convert::TryFrom; +use std::iter; use std::{ collections::{HashMap, HashSet}, sync::Arc, }; @@ -426,14 +427,17 @@ impl LogicalPlanBuilder { Ok(plan) } /// Apply a projection without alias. - pub fn project(&self, expr: impl IntoIterator<Item = Expr>) -> Result<Self> { + pub fn project( + &self, + expr: impl IntoIterator<Item = impl Into<Expr>>, + ) -> Result<Self> { self.project_with_alias(expr, None) } /// Apply a projection with alias pub fn project_with_alias( &self, - expr: impl IntoIterator<Item = Expr>, + expr: impl IntoIterator<Item = impl Into<Expr>>, alias: Option<String>, ) -> Result<Self> { Ok(Self::from(project_with_alias( @@ -444,8 +448,8 @@ } /// Apply a filter - pub fn filter(&self, expr: Expr) -> Result<Self> { - let expr = normalize_col(expr, &self.plan)?; + pub fn filter(&self, expr: impl Into<Expr>) -> Result<Self> { + let expr = normalize_col(expr.into(), &self.plan)?; Ok(Self::from(LogicalPlan::Filter { predicate: expr, input: Arc::new(self.plan.clone()), @@ -461,7 +465,7 @@ } /// Apply a sort - pub fn sort(&self, exprs: impl IntoIterator<Item = Expr>) -> Result<Self> { + pub fn sort(&self, exprs: impl IntoIterator<Item = impl Into<Expr>>) -> Result<Self> { Ok(Self::from(LogicalPlan::Sort { expr: normalize_cols(exprs, &self.plan)?, input: Arc::new(self.plan.clone()), @@ -477,7 +481,7 @@ pub fn distinct(&self) -> Result<Self> { let projection_expr = expand_wildcard(self.plan.schema(), &self.plan)?; let plan = LogicalPlanBuilder::from(self.plan.clone()) - .aggregate(projection_expr, vec![])? + .aggregate(projection_expr, iter::empty::<Expr>())? .build()?; Self::from(plan).project(vec![Expr::Wildcard]) } @@ -629,8 +633,11 @@ } /// Apply window functions to extend the schema - pub fn window(&self, window_expr: impl IntoIterator<Item = Expr>) -> Result<Self> { - let window_expr = window_expr.into_iter().collect::<Vec<Expr>>(); + pub fn window( + &self, + window_expr: impl IntoIterator<Item = impl Into<Expr>>, + ) -> Result<Self> { + let window_expr = normalize_cols(window_expr, &self.plan)?; let all_expr = window_expr.iter(); validate_unique_names("Windows", all_expr.clone(), self.plan.schema())?; let mut window_fields: Vec<DFField> = @@ -648,8 +655,8 @@ /// value of the `group_expr`; pub fn aggregate( &self, - group_expr: impl IntoIterator<Item = Expr>, - aggr_expr: impl IntoIterator<Item = Expr>, + group_expr: impl IntoIterator<Item = impl Into<Expr>>, + aggr_expr: impl IntoIterator<Item = impl Into<Expr>>, ) -> Result<Self> { let group_expr = normalize_cols(group_expr, &self.plan)?; let aggr_expr = normalize_cols(aggr_expr, &self.plan)?; @@ -796,12 +803,13 @@ pub fn union_with_alias( /// * An invalid expression is used (e.g. a `sort` expression) pub fn project_with_alias( plan: LogicalPlan, - expr: impl IntoIterator<Item = Expr>, + expr: impl IntoIterator<Item = impl Into<Expr>>, alias: Option<String>, ) -> Result<LogicalPlan> { let input_schema = plan.schema(); let mut projected_expr = vec![]; for e in expr { + let e = e.into(); match e { Expr::Wildcard => { projected_expr.extend(expand_wildcard(input_schema, &plan)?)
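The builder hunks above relax every expression-taking method from `impl IntoIterator<Item = Expr>` to `impl IntoIterator<Item = impl Into<Expr>>`. A side effect is that a bare `vec![]` no longer pins down the element type, which is why callers throughout this diff switch to `iter::empty::<Expr>()` or `Vec::<Expr>::new()`. A minimal sketch of the loosened API, not part of the patch, using existing `datafusion::logical_plan` helpers (`scan` stands in for any input plan):

use datafusion::error::Result;
use datafusion::logical_plan::{col, lit, max, Expr, LogicalPlan, LogicalPlanBuilder};
use std::iter;

fn sketch(scan: LogicalPlan) -> Result<LogicalPlan> {
    LogicalPlanBuilder::from(scan)
        // `filter` now accepts anything convertible into an `Expr`
        .filter(col("a").eq(lit(1)))?
        // an empty group-by list must now name its element type explicitly
        .aggregate(iter::empty::<Expr>(), vec![max(col("b"))])?
        .build()
}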
diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index 499a8c720dba..19e6fe36c7d6 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -540,6 +540,9 @@ impl Expr { /// This function errors when it is impossible to cast the /// expression to the target [arrow::datatypes::DataType]. pub fn cast_to(self, cast_to_type: &DataType, schema: &DFSchema) -> Result<Expr> { + // TODO(kszucs): most of the operations do not validate the type correctness + // like all of the binary expressions below. Perhaps Expr should track the + // type of the expression? let this_type = self.get_type(schema)?; if this_type == *cast_to_type { Ok(self) } @@ -1305,10 +1308,13 @@ fn normalize_col_with_schemas( /// Recursively normalize all Column expressions in a list of expression trees #[inline] pub fn normalize_cols( - exprs: impl IntoIterator<Item = Expr>, + exprs: impl IntoIterator<Item = impl Into<Expr>>, plan: &LogicalPlan, ) -> Result<Vec<Expr>> { - exprs.into_iter().map(|e| normalize_col(e, plan)).collect() + exprs + .into_iter() + .map(|e| normalize_col(e.into(), plan)) + .collect() } /// Recursively 'unnormalize' (remove all qualifiers) from an @@ -1544,6 +1550,8 @@ pub fn approx_distinct(expr: Expr) -> Expr { } } +// TODO(kszucs): this seems buggy, unary_scalar_expr! is used for many +// varying arity functions /// Create a convenience function representing a unary scalar function macro_rules! unary_scalar_expr { ($ENUM:ident, $FUNC:ident) => { diff --git a/datafusion/src/optimizer/common_subexpr_eliminate.rs b/datafusion/src/optimizer/common_subexpr_eliminate.rs index ea60286b902f..0e97663b5fef 100644 --- a/datafusion/src/optimizer/common_subexpr_eliminate.rs +++ b/datafusion/src/optimizer/common_subexpr_eliminate.rs @@ -631,6 +631,7 @@ mod test { avg, binary_expr, col, lit, sum, LogicalPlanBuilder, Operator, }; use crate::test::*; + use std::iter; fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) { let optimizer = CommonSubexprEliminate {}; @@ -688,7 +689,7 @@ let plan = LogicalPlanBuilder::from(table_scan) .aggregate( - vec![], + iter::empty::<Expr>(), vec![ sum(binary_expr( col("a"), @@ -723,7 +724,7 @@ let plan = LogicalPlanBuilder::from(table_scan) .aggregate( - vec![], + iter::empty::<Expr>(), vec![ binary_expr(lit(1), Operator::Plus, avg(col("a"))), binary_expr(lit(1), Operator::Minus, avg(col("a"))), diff --git a/datafusion/src/optimizer/projection_push_down.rs b/datafusion/src/optimizer/projection_push_down.rs index 2d66c5321acf..4fabc4f08f09 100644 --- a/datafusion/src/optimizer/projection_push_down.rs +++ b/datafusion/src/optimizer/projection_push_down.rs @@ -475,7 +475,7 @@ mod tests { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .aggregate(vec![], vec![max(col("b"))])? + .aggregate(Vec::<Expr>::new(), vec![max(col("b"))])? .build()?; let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(#test.b)]]\ @@ -508,7 +508,7 @@ let plan = LogicalPlanBuilder::from(table_scan) .filter(col("c"))? - .aggregate(vec![], vec![max(col("b"))])? + .aggregate(Vec::<Expr>::new(), vec![max(col("b"))])? .build()?; let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(#test.b)]]\
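The new `datafusion/src/pyarrow.rs` module below wires `ScalarValue` into pyo3's conversion traits on top of arrow's `PyArrowConvert`. A rough sketch of what that enables on the Rust side, not part of the patch (it assumes the `pyarrow` cargo feature is enabled and `pyarrow` is importable in the embedded Python):

use datafusion::scalar::ScalarValue;
use pyo3::prelude::*;

fn roundtrip() -> PyResult<ScalarValue> {
    Python::with_gil(|py| {
        // build a pyarrow scalar on the Python side...
        let value = py.import("pyarrow")?.getattr("scalar")?.call1((42,))?;
        // ...and pull it across via the `FromPyObject` impl added below
        value.extract::<ScalarValue>()
    })
}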
diff --git a/datafusion/src/pyarrow.rs b/datafusion/src/pyarrow.rs new file mode 100644 index 000000000000..da05d63d8c2c --- /dev/null +++ b/datafusion/src/pyarrow.rs @@ -0,0 +1,67 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use pyo3::exceptions::{PyException, PyNotImplementedError}; +use pyo3::prelude::*; +use pyo3::types::PyList; +use pyo3::PyNativeType; + +use crate::arrow::array::ArrayData; +use crate::arrow::pyarrow::PyArrowConvert; +use crate::error::DataFusionError; +use crate::scalar::ScalarValue; + +impl From<DataFusionError> for PyErr { + fn from(err: DataFusionError) -> PyErr { + PyException::new_err(err.to_string()) + } +} + +impl PyArrowConvert for ScalarValue { + fn from_pyarrow(value: &PyAny) -> PyResult<Self> { + let py = value.py(); + let typ = value.getattr("type")?; + let val = value.call_method0("as_py")?; + + // construct pyarrow array from the python value and pyarrow type + let factory = py.import("pyarrow")?.getattr("array")?; + let args = PyList::new(py, &[val]); + let array = factory.call1((args, typ))?; + + // convert the pyarrow array to rust array using C data interface + let array = array.extract::<ArrayData>()?; + let scalar = ScalarValue::try_from_array(&array.into(), 0)?; + + Ok(scalar) + } + + fn to_pyarrow(&self, _py: Python) -> PyResult<PyObject> { + Err(PyNotImplementedError::new_err("Not implemented")) + } +} + +impl<'source> FromPyObject<'source> for ScalarValue { + fn extract(value: &'source PyAny) -> PyResult<Self> { + Self::from_pyarrow(value) + } +} + +impl IntoPy<PyObject> for ScalarValue { + fn into_py(self, py: Python) -> PyObject { + self.to_pyarrow(py).unwrap() + } +} diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs index 00586bf5549e..33bc9dd10486 100644 --- a/datafusion/src/scalar.rs +++ b/datafusion/src/scalar.rs @@ -1093,6 +1093,10 @@ impl ScalarValue { DataType::Int32 => typed_cast!(array, index, Int32Array, Int32), DataType::Int16 => typed_cast!(array, index, Int16Array, Int16), DataType::Int8 => typed_cast!(array, index, Int8Array, Int8), + DataType::Binary => typed_cast!(array, index, BinaryArray, Binary), + DataType::LargeBinary => { + typed_cast!(array, index, LargeBinaryArray, LargeBinary) + } DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8), DataType::LargeUtf8 => typed_cast!(array, index, LargeStringArray, LargeUtf8), DataType::List(nested_type) => { diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index 60d2da8be2c7..1653cb5d5ac5 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -18,6 +18,7 @@ //! SQL Query Planner (produces logical plan from SQL AST) use std::collections::HashSet; +use std::iter; use std::str::FromStr; use std::sync::Arc; use std::{convert::TryInto, vec}; @@ -822,7 +823,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let plan = if select.distinct { return LogicalPlanBuilder::from(plan) - .aggregate(select_exprs_post_aggr, vec![])? + .aggregate(select_exprs_post_aggr, iter::empty::<Expr>())?
.build(); } else { plan diff --git a/python/Cargo.lock b/python/Cargo.lock index 6ae27021e61c..fa84a54ced7b 100644 --- a/python/Cargo.lock +++ b/python/Cargo.lock @@ -10,9 +10,9 @@ checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" [[package]] name = "ahash" -version = "0.7.4" +version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43bb833f0bf979d8475d38fbf09ed3b8a55e1885fe93ad3f93239fc6a4f17b98" +checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47" dependencies = [ "getrandom 0.2.3", "once_cell", @@ -72,6 +72,7 @@ dependencies = [ "lexical-core", "multiversion", "num", + "pyo3", "rand 0.8.4", "regex", "serde", @@ -121,9 +122,9 @@ dependencies = [ [[package]] name = "blake3" -version = "1.0.0" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcd555c66291d5f836dbb6883b48660ece810fe25a31f3bdfb911945dff2691f" +checksum = "2607a74355ce2e252d0c483b2d8a348e1bba36036e786ccc2dcd777213c86ffd" dependencies = [ "arrayref", "arrayvec", @@ -165,9 +166,9 @@ dependencies = [ [[package]] name = "bstr" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90682c8d613ad3373e66de8c6411e0ae2ab2571e879d2efbf73558cc66f21279" +checksum = "ba3569f383e8f1598449f1a423e72e99569137b47740b1da11ef19af3d5c3223" dependencies = [ "lazy_static", "memchr", @@ -183,9 +184,9 @@ checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" [[package]] name = "cc" -version = "1.0.70" +version = "1.0.71" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d26a6ce4b6a484fa3edb70f7efa6fc430fd2b87285fe8b84304fd0936faa0dc0" +checksum = "79c2681d6594606957bbb8631c4b90a7fcaaa72cdb714743a437b156d6a7eedd" dependencies = [ "jobserver", ] @@ -296,6 +297,7 @@ dependencies = [ "parquet", "paste 1.0.5", "pin-project-lite", + "pyo3", "rand 0.8.4", "regex", "sha2", @@ -311,7 +313,6 @@ name = "datafusion-python" version = "0.3.0" dependencies = [ "datafusion", - "libc", "pyo3", "rand 0.7.3", "tokio", @@ -340,9 +341,9 @@ dependencies = [ [[package]] name = "flate2" -version = "1.0.21" +version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80edafed416a46fb378521624fab1cfa2eb514784fd8921adbe8a8d8321da811" +checksum = "1e6988e897c1c9c485f43b47a529cef42fde0547f9d8d41a7062518f1d8fc53f" dependencies = [ "cfg-if", "crc32fast", @@ -544,9 +545,9 @@ dependencies = [ [[package]] name = "instant" -version = "0.1.10" +version = "0.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bee0328b1209d157ef001c94dd85b4f8f64139adb0eac2659f4b08382b2f474d" +checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c" dependencies = [ "cfg-if", ] @@ -644,9 +645,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.101" +version = "0.2.105" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3cb00336871be5ed2c8ed44b60ae9959dc5b9f08539422ed43f09e34ecaeba21" +checksum = "869d572136620d55835903746bcb5cdc54cb2851fd0aeec53220b4bb65ef3013" [[package]] name = "lock_api" @@ -943,9 +944,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "ppv-lite86" -version = "0.2.10" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac74c624d6b2d21f425f752262f42188365d7b8ff1aff74c82e45136510a4857" +checksum = 
"ed0cfbc8191465bed66e1718596ee0b0b35d5ee1f41c5df2189d0fe8bde535ba" [[package]] name = "proc-macro-hack" @@ -961,9 +962,9 @@ checksum = "bc881b2c22681370c6a780e47af9840ef841837bc98118431d4e1868bd0c1086" [[package]] name = "proc-macro2" -version = "1.0.28" +version = "1.0.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c7ed8b8c7b886ea3ed7dde405212185f423ab44682667c8c6dd14aa1d9f6612" +checksum = "edc3358ebc67bc8b7fa0c007f945b0b18226f78437d61bec735a9eb96b61ee70" dependencies = [ "unicode-xid", ] @@ -1018,9 +1019,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.9" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3d0b9745dc2debf507c8422de05d7226cc1f0644216dfdfead988f9b1ab32a7" +checksum = "38bc8cc6a5f2e3655e0899c1b848643b2562f853f114bfec7be120678e3ace05" dependencies = [ "proc-macro2", ] @@ -1169,9 +1170,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.67" +version = "1.0.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7f9e390c27c3c0ce8bc5d725f6e4d30a29d26659494aa4b17535f7522c5c950" +checksum = "0f690853975602e1bfe1ccbf50504d67174e3bcf340f23b5ea9992e0587a52d8" dependencies = [ "indexmap", "itoa", @@ -1181,9 +1182,9 @@ dependencies = [ [[package]] name = "sha2" -version = "0.9.6" +version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9204c41a1597a8c5af23c82d1c921cb01ec0a4c59e07a9c7306062829a3903f3" +checksum = "b69f9a4c9740d74c5baa3fd2e547f9525fa8088a8a958e0ca2409a514e33f5fa" dependencies = [ "block-buffer", "cfg-if", @@ -1194,15 +1195,15 @@ dependencies = [ [[package]] name = "slab" -version = "0.4.4" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c307a32c1c5c437f38c7fd45d753050587732ba8628319fbdf12a7e289ccc590" +checksum = "9def91fd1e018fe007022791f865d0ccc9b3a0d5001e01aabb8b40e46000afb5" [[package]] name = "smallvec" -version = "1.6.1" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe0f37c9e8f3c5a4a66ad655a93c74daac4ad00c441533bf5c6e7990bb42604e" +checksum = "1ecab6c735a6bb4139c0caafd0cc3635748bbb3acf4550e8138122099251f309" [[package]] name = "snap" @@ -1251,9 +1252,9 @@ checksum = "6bdef32e8150c2a081110b42772ffe7d7c9032b606bc226c8260fd97e0976601" [[package]] name = "syn" -version = "1.0.76" +version = "1.0.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6f107db402c2c2055242dbf4d2af0e69197202e9faacbef9571bbe47f5a1b84" +checksum = "d010a1623fbd906d51d650a9916aaefc05ffa0e4053ff7fe601167f3e715d194" dependencies = [ "proc-macro2", "quote", @@ -1262,18 +1263,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.29" +version = "1.0.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "602eca064b2d83369e2b2f34b09c70b605402801927c65c11071ac911d299b88" +checksum = "854babe52e4df1653706b98fcfc05843010039b406875930a70e4d9644e5c417" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.29" +version = "1.0.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bad553cc2c78e8de258400763a647e80e6d1b31ee237275d756f6836d204494c" +checksum = "aa32fd3f627f367fe16f893e2597ae3c05020f8bba2666a4e6ea73d377e5714b" dependencies = [ "proc-macro2", "quote", @@ -1314,9 +1315,9 @@ dependencies = [ [[package]] name = "tokio" -version = "1.11.0" +version = "1.12.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4efe6fc2395938c8155973d7be49fe8d03a843726e285e100a8a383cc0154ce" +checksum = "c2c2416fdedca8443ae44b4527de1ea633af61d8f7169ffa6e72c5b53d24efcc" dependencies = [ "autocfg", "num_cpus", @@ -1326,9 +1327,9 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "1.3.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54473be61f4ebe4efd09cec9bd5d16fa51d70ea0192213d754d2d500457db110" +checksum = "b2dd85aeaba7b68df939bd357c6afb36c87951be9e80bf9c859f2fc3e9fca0fd" dependencies = [ "proc-macro2", "quote", @@ -1360,9 +1361,9 @@ checksum = "8895849a949e7845e06bd6dc1aa51731a103c42707010a5b591c0038fb73385b" [[package]] name = "unicode-width" -version = "0.1.8" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9337591893a19b88d8d87f2cec1e73fad5cdfd10e5a6f349f498ad6ea2ffb1e3" +checksum = "3ed742d4ea2bd1176e236172c8429aaf54486e7ac098db29ffe6529e0ce50973" [[package]] name = "unicode-xid" diff --git a/python/Cargo.toml b/python/Cargo.toml index c0645a152078..3d3ebfa34540 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -28,17 +28,19 @@ edition = "2021" rust-version = "1.56" [dependencies] -libc = "0.2" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] } rand = "0.7" -pyo3 = { version = "0.14.1", features = ["extension-module", "abi3", "abi3-py36"] } -datafusion = { path = "../datafusion", version = "5.1.0" } +pyo3 = { version = "0.14", features = ["extension-module", "abi3", "abi3-py36"] } +datafusion = { path = "../datafusion", version = "5.1.0", features = ["pyarrow"] } uuid = { version = "0.8", features = ["v4"] } [lib] -name = "datafusion" +name = "_internal" crate-type = ["cdylib"] +[package.metadata.maturin] +name = "datafusion._internal" + [profile.release] lto = true codegen-units = 1 diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py new file mode 100644 index 000000000000..4f9082e7e402 --- /dev/null +++ b/python/datafusion/__init__.py @@ -0,0 +1,104 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from abc import ABCMeta, abstractmethod +from typing import List + +import pyarrow as pa + +from ._internal import ( + AggregateUDF, + DataFrame, + ExecutionContext, + Expression, + ScalarUDF, +) + +__all__ = [ + "DataFrame", + "ExecutionContext", + "Expression", + "AggregateUDF", + "ScalarUDF", + "column", + "literal", +] + + +class Accumulator(metaclass=ABCMeta): + @abstractmethod + def state(self) -> List[pa.Scalar]: + pass + + @abstractmethod + def update(self, values: pa.Array) -> None: + pass + + @abstractmethod + def merge(self, states: pa.Array) -> None: + pass + + @abstractmethod + def evaluate(self) -> pa.Scalar: + pass + + +def column(value): + return Expression.column(value) + + +def literal(value): + if not isinstance(value, pa.Scalar): + value = pa.scalar(value) + return Expression.literal(value) + + +def udf(func, input_types, return_type, volatility, name=None): + """ + Create a new User Defined Function + """ + if not callable(func): + raise TypeError("`func` argument must be callable") + if name is None: + name = func.__qualname__ + return ScalarUDF( + name=name, + func=func, + input_types=input_types, + return_type=return_type, + volatility=volatility, + ) + + +def udaf(accum, input_type, return_type, state_type, volatility, name=None): + """ + Create a new User Defined Aggregate Function + """ + if not issubclass(accum, Accumulator): + raise TypeError( + "`accum` must implement the abstract base class Accumulator" + ) + if name is None: + name = accum.__qualname__ + return AggregateUDF( + name=name, + accumulator=accum, + input_type=input_type, + return_type=return_type, + state_type=state_type, + volatility=volatility, + ) diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py new file mode 100644 index 000000000000..782ecba22191 --- /dev/null +++ b/python/datafusion/functions.py @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ + +from ._internal import functions + + +def __getattr__(name): + return getattr(functions, name) diff --git a/python/tests/__init__.py b/python/datafusion/tests/__init__.py similarity index 100% rename from python/tests/__init__.py rename to python/datafusion/tests/__init__.py diff --git a/python/tests/generic.py b/python/datafusion/tests/generic.py similarity index 100% rename from python/tests/generic.py rename to python/datafusion/tests/generic.py diff --git a/python/tests/test_aggregation.py b/python/datafusion/tests/test_aggregation.py similarity index 94% rename from python/tests/test_aggregation.py rename to python/datafusion/tests/test_aggregation.py index f0996f9e06d9..d539c44585a6 100644 --- a/python/tests/test_aggregation.py +++ b/python/datafusion/tests/test_aggregation.py @@ -17,7 +17,8 @@ import pyarrow as pa import pytest -from datafusion import ExecutionContext + +from datafusion import ExecutionContext, column from datafusion import functions as f @@ -34,8 +35,8 @@ def df(): def test_built_in_aggregation(df): - col_a = f.col("a") - col_b = f.col("b") + col_a = column("a") + col_b = column("b") df = df.aggregate( [], [f.max(col_a), f.min(col_a), f.count(col_a), f.approx_distinct(col_b)], diff --git a/python/datafusion/tests/test_catalog.py b/python/datafusion/tests/test_catalog.py new file mode 100644 index 000000000000..2e64a810a718 --- /dev/null +++ b/python/datafusion/tests/test_catalog.py @@ -0,0 +1,72 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import pyarrow as pa +import pytest + +from datafusion import ExecutionContext + + +@pytest.fixture +def ctx(): + return ExecutionContext() + + +@pytest.fixture +def database(ctx, tmp_path): + path = tmp_path / "test.csv" + + table = pa.Table.from_arrays( + [ + [1, 2, 3, 4], + ["a", "b", "c", "d"], + [1.1, 2.2, 3.3, 4.4], + ], + names=["int", "str", "float"], + ) + pa.csv.write_csv(table, path) + + ctx.register_csv("csv", path) + ctx.register_csv("csv1", str(path)) + ctx.register_csv( + "csv2", + path, + has_header=True, + delimiter=",", + schema_infer_max_records=10, + ) + + +def test_basic(ctx, database): + with pytest.raises(KeyError): + ctx.catalog("non-existent") + + default = ctx.catalog() + assert default.names() == ["public"] + + for database in [default.database("public"), default.database()]: + assert database.names() == {"csv1", "csv", "csv2"} + + table = database.table("csv") + assert table.kind == "physical" + assert table.schema == pa.schema( + [ + pa.field("int", pa.int64(), nullable=False), + pa.field("str", pa.string(), nullable=False), + pa.field("float", pa.float64(), nullable=False), + ] + ) diff --git a/python/tests/test_df_sql.py b/python/datafusion/tests/test_context.py similarity index 99% rename from python/tests/test_df_sql.py rename to python/datafusion/tests/test_context.py index c6eac6bb2ffc..60beea4a01be 100644 --- a/python/tests/test_df_sql.py +++ b/python/datafusion/tests/test_context.py @@ -17,6 +17,7 @@ import pyarrow as pa import pytest + from datafusion import ExecutionContext diff --git a/python/tests/test_df.py b/python/datafusion/tests/test_dataframe.py similarity index 75% rename from python/tests/test_df.py rename to python/datafusion/tests/test_dataframe.py index 9bbdb5a30077..0eb970a69e83 100644 --- a/python/tests/test_df.py +++ b/python/datafusion/tests/test_dataframe.py @@ -17,8 +17,8 @@ import pyarrow as pa import pytest -from datafusion import ExecutionContext -from datafusion import functions as f + +from datafusion import DataFrame, ExecutionContext, column, literal, udf @pytest.fixture @@ -36,8 +36,8 @@ def df(): def test_select(df): df = df.select( - f.col("a") + f.col("b"), - f.col("a") - f.col("b"), + column("a") + column("b"), + column("a") - column("b"), ) # execute and collect the first (and only) batch @@ -49,9 +49,9 @@ def test_select(df): def test_filter(df): df = df.select( - f.col("a") + f.col("b"), - f.col("a") - f.col("b"), - ).filter(f.col("a") > f.lit(2)) + column("a") + column("b"), + column("a") - column("b"), + ).filter(column("a") > literal(2)) # execute and collect the first (and only) batch result = df.collect()[0] @@ -61,7 +61,7 @@ def test_filter(df): def test_sort(df): - df = df.sort([f.col("b").sort(ascending=False)]) + df = df.sort(column("b").sort(ascending=False)) table = pa.Table.from_batches(df.collect()) expected = {"a": [3, 2, 1], "b": [6, 5, 4]} @@ -81,14 +81,14 @@ def test_limit(df): def test_udf(df): # is_null is a pa function over arrays - udf = f.udf( + is_null = udf( lambda x: x.is_null(), [pa.int64()], pa.bool_(), - f.Volatility.immutable(), + volatility="immutable", ) - df = df.select(udf(f.col("a"))) + df = df.select(is_null(column("a"))) result = df.collect()[0].column(0) assert result == pa.array([False, False, False]) @@ -110,8 +110,28 @@ def test_join(): df1 = ctx.create_dataframe([[batch]]) df = df.join(df1, join_keys=(["a"], ["a"]), how="inner") - df = df.sort([f.col("a").sort(ascending=True)]) + df = df.sort(column("a").sort(ascending=True)) table = pa.Table.from_batches(df.collect()) 
expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]} assert table.to_pydict() == expected + + +def test_get_dataframe(tmp_path): + ctx = ExecutionContext() + + path = tmp_path / "test.csv" + table = pa.Table.from_arrays( + [ + [1, 2, 3, 4], + ["a", "b", "c", "d"], + [1.1, 2.2, 3.3, 4.4], + ], + names=["int", "str", "float"], + ) + pa.csv.write_csv(table, path) + + ctx.register_csv("csv", path) + + df = ctx.table("csv") + assert isinstance(df, DataFrame) diff --git a/python/datafusion/tests/test_functions.py b/python/datafusion/tests/test_functions.py new file mode 100644 index 000000000000..84718eaf0ce6 --- /dev/null +++ b/python/datafusion/tests/test_functions.py @@ -0,0 +1,219 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import numpy as np +import pyarrow as pa +import pytest + +from datafusion import ExecutionContext, column +from datafusion import functions as f +from datafusion import literal + + +@pytest.fixture +def df(): + ctx = ExecutionContext() + # create a RecordBatch and a new DataFrame from it + batch = pa.RecordBatch.from_arrays( + [pa.array(["Hello", "World", "!"]), pa.array([4, 5, 6])], + names=["a", "b"], + ) + return ctx.create_dataframe([[batch]]) + + +def test_literal(df): + df = df.select( + literal(1), + literal("1"), + literal("OK"), + literal(3.14), + literal(True), + literal(b"hello world"), + ) + result = df.collect() + assert len(result) == 1 + result = result[0] + assert result.column(0) == pa.array([1] * 3) + assert result.column(1) == pa.array(["1"] * 3) + assert result.column(2) == pa.array(["OK"] * 3) + assert result.column(3) == pa.array([3.14] * 3) + assert result.column(4) == pa.array([True] * 3) + assert result.column(5) == pa.array([b"hello world"] * 3) + + +def test_lit_arith(df): + """ + Test literals with arithmetic operations + """ + df = df.select( + literal(1) + column("b"), f.concat(column("a"), literal("!")) + ) + result = df.collect() + assert len(result) == 1 + result = result[0] + assert result.column(0) == pa.array([5, 6, 7]) + assert result.column(1) == pa.array(["Hello!", "World!", "!!"]) + + +def test_math_functions(): + ctx = ExecutionContext() + # create a RecordBatch and a new DataFrame from it + batch = pa.RecordBatch.from_arrays( + [pa.array([0.1, -0.7, 0.55])], names=["value"] + ) + df = ctx.create_dataframe([[batch]]) + + values = np.array([0.1, -0.7, 0.55]) + col_v = column("value") + df = df.select( + f.abs(col_v), + f.sin(col_v), + f.cos(col_v), + f.tan(col_v), + f.asin(col_v), + f.acos(col_v), + f.exp(col_v), + f.ln(col_v + literal(pa.scalar(1))), + f.log2(col_v + literal(pa.scalar(1))), + f.log10(col_v + literal(pa.scalar(1))), + f.random(), + ) + batches = df.collect() + assert len(batches) == 1 + result = batches[0] + + np.testing.assert_array_almost_equal(result.column(0), 
np.abs(values)) + np.testing.assert_array_almost_equal(result.column(1), np.sin(values)) + np.testing.assert_array_almost_equal(result.column(2), np.cos(values)) + np.testing.assert_array_almost_equal(result.column(3), np.tan(values)) + np.testing.assert_array_almost_equal(result.column(4), np.arcsin(values)) + np.testing.assert_array_almost_equal(result.column(5), np.arccos(values)) + np.testing.assert_array_almost_equal(result.column(6), np.exp(values)) + np.testing.assert_array_almost_equal( + result.column(7), np.log(values + 1.0) + ) + np.testing.assert_array_almost_equal( + result.column(8), np.log2(values + 1.0) + ) + np.testing.assert_array_almost_equal( + result.column(9), np.log10(values + 1.0) + ) + np.testing.assert_array_less(result.column(10), np.ones_like(values)) + + +def test_string_functions(df): + df = df.select(f.md5(column("a")), f.lower(column("a"))) + result = df.collect() + assert len(result) == 1 + result = result[0] + assert result.column(0) == pa.array( + [ + "8b1a9953c4611296a827abf8c47804d7", + "f5a7924e621e84c9280a9a27e1bcb7f6", + "9033e0e305f247c0c3c80d0c7848c8b3", + ] + ) + assert result.column(1) == pa.array(["hello", "world", "!"]) + + +def test_hash_functions(df): + exprs = [ + f.digest(column("a"), literal(m)) + for m in ("md5", "sha256", "sha512", "blake2s", "blake3") + ] + df = df.select(*exprs) + result = df.collect() + assert len(result) == 1 + result = result[0] + b = bytearray.fromhex + assert result.column(0) == pa.array( + [ + b("8B1A9953C4611296A827ABF8C47804D7"), + b("F5A7924E621E84C9280A9A27E1BCB7F6"), + b("9033E0E305F247C0C3C80D0C7848C8B3"), + ] + ) + assert result.column(1) == pa.array( + [ + b( + "185F8DB32271FE25F561A6FC938B2E26" + "4306EC304EDA518007D1764826381969" + ), + b( + "78AE647DC5544D227130A0682A51E30B" + "C7777FBB6D8A8F17007463A3ECD1D524" + ), + b( + "BB7208BC9B5D7C04F1236A82A0093A5E" + "33F40423D5BA8D4266F7092C3BA43B62" + ), + ] + ) + assert result.column(2) == pa.array( + [ + b( + "3615F80C9D293ED7402687F94B22D58E" + "529B8CC7916F8FAC7FDDF7FBD5AF4CF7" + "77D3D795A7A00A16BF7E7F3FB9561EE9" + "BAAE480DA9FE7A18769E71886B03F315" + ), + b( + "8EA77393A42AB8FA92500FB077A9509C" + "C32BC95E72712EFA116EDAF2EDFAE34F" + "BB682EFDD6C5DD13C117E08BD4AAEF71" + "291D8AACE2F890273081D0677C16DF0F" + ), + b( + "3831A6A6155E509DEE59A7F451EB3532" + "4D8F8F2DF6E3708894740F98FDEE2388" + "9F4DE5ADB0C5010DFB555CDA77C8AB5D" + "C902094C52DE3278F35A75EBC25F093A" + ), + ] + ) + assert result.column(3) == pa.array( + [ + b( + "F73A5FBF881F89B814871F46E26AD3FA" + "37CB2921C5E8561618639015B3CCBB71" + ), + b( + "B792A0383FB9E7A189EC150686579532" + "854E44B71AC394831DAED169BA85CCC5" + ), + b( + "27988A0E51812297C77A433F63523334" + "6AEE29A829DCF4F46E0F58F402C6CFCB" + ), + ] + ) + assert result.column(4) == pa.array( + [ + b( + "FBC2B0516EE8744D293B980779178A35" + "08850FDCFE965985782C39601B65794F" + ), + b( + "BF73D18575A736E4037D45F9E316085B" + "86C19BE6363DE6AA789E13DEAACC1C4E" + ), + b( + "C8D11B9F7237E4034ADBCD2005735F9B" + "C4C597C75AD89F4492BEC8F77D15F7EB" + ), + ] + ) diff --git a/python/datafusion/tests/test_imports.py b/python/datafusion/tests/test_imports.py new file mode 100644 index 000000000000..423800248a5c --- /dev/null +++ b/python/datafusion/tests/test_imports.py @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +import datafusion +from datafusion import ( + AggregateUDF, + DataFrame, + ExecutionContext, + Expression, + ScalarUDF, + functions, +) + + +def test_import_datafusion(): + assert datafusion.__name__ == "datafusion" + + +def test_class_module_is_datafusion(): + for klass in [ + ExecutionContext, + Expression, + DataFrame, + ScalarUDF, + AggregateUDF, + ]: + assert klass.__module__ == "datafusion" + + +def test_import_from_functions_submodule(): + from datafusion.functions import abs, sin # noqa + + assert functions.abs is abs + assert functions.sin is sin + + msg = "cannot import name 'foobar' from 'datafusion.functions'" + with pytest.raises(ImportError, match=msg): + from datafusion.functions import foobar # noqa + + +def test_classes_are_inheritable(): + class MyExecContext(ExecutionContext): + pass + + class MyExpression(Expression): + pass + + class MyDataFrame(DataFrame): + pass diff --git a/python/tests/test_sql.py b/python/datafusion/tests/test_sql.py similarity index 85% rename from python/tests/test_sql.py rename to python/datafusion/tests/test_sql.py index f309a85104b2..23f20079f0da 100644 --- a/python/tests/test_sql.py +++ b/python/datafusion/tests/test_sql.py @@ -19,8 +19,8 @@ import pyarrow as pa import pytest -from datafusion import ExecutionContext -from datafusion import functions as f +from datafusion import ExecutionContext, udf + from . 
import generic as helpers @@ -68,9 +68,9 @@ def test_register_csv(ctx, tmp_path): assert ctx.tables() == {"csv", "csv1", "csv2", "csv3"} for table in ["csv", "csv1", "csv2"]: - result = ctx.sql(f"SELECT COUNT(int) FROM {table}").collect() + result = ctx.sql(f"SELECT COUNT(int) AS cnt FROM {table}").collect() result = pa.Table.from_batches(result) - assert result.to_pydict() == {f"COUNT({table}.int)": [4]} + assert result.to_pydict() == {"cnt": [4]} result = ctx.sql("SELECT * FROM csv3").collect() result = pa.Table.from_batches(result) @@ -87,9 +87,9 @@ def test_register_parquet(ctx, tmp_path): ctx.register_parquet("t", path) assert ctx.tables() == {"t"} - result = ctx.sql("SELECT COUNT(a) FROM t").collect() + result = ctx.sql("SELECT COUNT(a) AS cnt FROM t").collect() result = pa.Table.from_batches(result) - assert result.to_pydict() == {"COUNT(t.a)": [100]} + assert result.to_pydict() == {"cnt": [100]} def test_execute(ctx, tmp_path): @@ -102,21 +102,21 @@ def test_execute(ctx, tmp_path): assert ctx.tables() == {"t"} # count - result = ctx.sql("SELECT COUNT(a) FROM t").collect() + result = ctx.sql("SELECT COUNT(a) AS cnt FROM t").collect() expected = pa.array([7], pa.uint64()) - expected = [pa.RecordBatch.from_arrays([expected], ["COUNT(a)"])] + expected = [pa.RecordBatch.from_arrays([expected], ["cnt"])] assert result == expected # where expected = pa.array([2], pa.uint64()) - expected = [pa.RecordBatch.from_arrays([expected], ["COUNT(a)"])] - result = ctx.sql("SELECT COUNT(a) FROM t WHERE a > 10").collect() + expected = [pa.RecordBatch.from_arrays([expected], ["cnt"])] + result = ctx.sql("SELECT COUNT(a) AS cnt FROM t WHERE a > 10").collect() assert result == expected # group by results = ctx.sql( - "SELECT CAST(a as int), COUNT(a) FROM t GROUP BY CAST(a as int)" + "SELECT CAST(a as int) AS a, COUNT(a) AS cnt FROM t GROUP BY a" ).collect() # group by returns batches @@ -124,8 +124,8 @@ def test_execute(ctx, tmp_path): result_values = [] for result in results: pydict = result.to_pydict() - result_keys.extend(pydict["CAST(t.a AS Int32)"]) - result_values.extend(pydict["COUNT(t.a)"]) + result_keys.extend(pydict["a"]) + result_values.extend(pydict["cnt"]) result_keys, result_values = ( list(t) for t in zip(*sorted(zip(result_keys, result_values))) @@ -136,14 +136,12 @@ def test_execute(ctx, tmp_path): # order by result = ctx.sql( - "SELECT a, CAST(a AS int) FROM t ORDER BY a DESC LIMIT 2" + "SELECT a, CAST(a AS int) AS a_int FROM t ORDER BY a DESC LIMIT 2" ).collect() expected_a = pa.array([50.0219, 50.0152], pa.float64()) expected_cast = pa.array([50, 50], pa.int32()) expected = [ - pa.RecordBatch.from_arrays( - [expected_a, expected_cast], ["a", "CAST(t.a AS Int32)"] - ) + pa.RecordBatch.from_arrays([expected_a, expected_cast], ["a", "a_int"]) ] np.testing.assert_equal(expected[0].column(1), expected[0].column(1)) @@ -199,11 +197,13 @@ def test_udf( tmp_path / "a.parquet", pa.array(input_values) ) ctx.register_parquet("t", path) - ctx.register_udf( - "udf", fn, input_types, output_type, f.Volatility.immutable() + + func = udf( + fn, input_types, output_type, name="func", volatility="immutable" ) + ctx.register_udf(func) - batches = ctx.sql("SELECT udf(a) AS tt FROM t").collect() + batches = ctx.sql("SELECT func(a) AS tt FROM t").collect() result = batches[0].column(0) assert result == pa.array(expected_values) diff --git a/python/tests/test_udaf.py b/python/datafusion/tests/test_udaf.py similarity index 65% rename from python/tests/test_udaf.py rename to 
python/datafusion/tests/test_udaf.py index 7ff622330ccc..2f286ba105dd 100644 --- a/python/tests/test_udaf.py +++ b/python/datafusion/tests/test_udaf.py @@ -20,11 +20,11 @@ import pyarrow as pa import pyarrow.compute as pc import pytest -from datafusion import ExecutionContext -from datafusion import functions as f +from datafusion import Accumulator, ExecutionContext, column, udaf -class Accumulator: + +class Summarize(Accumulator): """ Interface of a user-defined accumulation. """ @@ -32,7 +32,7 @@ class Accumulator: def __init__(self): self._sum = pa.scalar(0.0) - def to_scalars(self) -> List[pa.Scalar]: + def state(self) -> List[pa.Scalar]: return [self._sum] def update(self, values: pa.Array) -> None: @@ -49,6 +49,18 @@ def evaluate(self) -> pa.Scalar: return self._sum +class NotSubclassOfAccumulator: + pass + + +class MissingMethods(Accumulator): + def __init__(self): + self._sum = pa.scalar(0) + + def state(self) -> List[pa.Scalar]: + return [self._sum] + + @pytest.fixture def df(): ctx = ExecutionContext() @@ -61,16 +73,43 @@ def df(): return ctx.create_dataframe([[batch]]) +def test_errors(df): + with pytest.raises(TypeError): + udaf( + NotSubclassOfAccumulator, + pa.float64(), + pa.float64(), + [pa.float64()], + volatility="immutable", + ) + + accum = udaf( + MissingMethods, + pa.int64(), + pa.int64(), + [pa.int64()], + volatility="immutable", + ) + df = df.aggregate([], [accum(column("a"))]) + + msg = ( + "Can't instantiate abstract class MissingMethods with abstract " + "methods evaluate, merge, update" + ) + with pytest.raises(Exception, match=msg): + df.collect() + + def test_aggregate(df): - udaf = f.udaf( - Accumulator, + summarize = udaf( + Summarize, pa.float64(), pa.float64(), [pa.float64()], - f.Volatility.immutable(), + volatility="immutable", ) - df = df.aggregate([], [udaf(f.col("a"))]) + df = df.aggregate([], [summarize(column("a"))]) # execute and collect the first (and only) batch result = df.collect()[0] @@ -79,17 +118,18 @@ def test_aggregate(df): def test_group_by(df): - udaf = f.udaf( - Accumulator, + summarize = udaf( + Summarize, pa.float64(), pa.float64(), [pa.float64()], - f.Volatility.immutable(), + volatility="immutable", ) - df = df.aggregate([f.col("b")], [udaf(f.col("a"))]) + df = df.aggregate([column("b")], [summarize(column("a"))]) batches = df.collect() + arrays = [batch.column(1) for batch in batches] joined = pa.concat_arrays(arrays) assert joined == pa.array([1.0 + 2.0, 3.0]) diff --git a/python/pyproject.toml b/python/pyproject.toml index f366aa94ddf4..c6ee363497d7 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -50,3 +50,6 @@ dependencies = [ [project.urls] documentation = "https://arrow.apache.org/datafusion/python" repository = "https://github.com/apache/arrow-datafusion" + +[tool.isort] +profile = "black" diff --git a/python/src/catalog.rs b/python/src/catalog.rs new file mode 100644 index 000000000000..f93c795ec34c --- /dev/null +++ b/python/src/catalog.rs @@ -0,0 +1,123 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::collections::HashSet; +use std::sync::Arc; + +use pyo3::exceptions::PyKeyError; +use pyo3::prelude::*; + +use datafusion::{ + arrow::pyarrow::PyArrowConvert, + catalog::{catalog::CatalogProvider, schema::SchemaProvider}, + datasource::{TableProvider, TableType}, +}; + +#[pyclass(name = "Catalog", module = "datafusion", subclass)] +pub(crate) struct PyCatalog { + catalog: Arc<dyn CatalogProvider>, +} + +#[pyclass(name = "Database", module = "datafusion", subclass)] +pub(crate) struct PyDatabase { + database: Arc<dyn SchemaProvider>, +} + +#[pyclass(name = "Table", module = "datafusion", subclass)] +pub(crate) struct PyTable { + table: Arc<dyn TableProvider>, +} + +impl PyCatalog { + pub fn new(catalog: Arc<dyn CatalogProvider>) -> Self { + Self { catalog } + } +} + +impl PyDatabase { + pub fn new(database: Arc<dyn SchemaProvider>) -> Self { + Self { database } + } +} + +impl PyTable { + pub fn new(table: Arc<dyn TableProvider>) -> Self { + Self { table } + } +} + +#[pymethods] +impl PyCatalog { + fn names(&self) -> Vec<String> { + self.catalog.schema_names() + } + + #[args(name = "\"public\"")] + fn database(&self, name: &str) -> PyResult<PyDatabase> { + match self.catalog.schema(name) { + Some(database) => Ok(PyDatabase::new(database)), + None => Err(PyKeyError::new_err(format!( + "Database with name {} doesn't exist.", + name + ))), + } + } +} + +#[pymethods] +impl PyDatabase { + fn names(&self) -> HashSet<String> { + self.database.table_names().into_iter().collect() + } + + fn table(&self, name: &str) -> PyResult<PyTable> { + match self.database.table(name) { + Some(table) => Ok(PyTable::new(table)), + None => Err(PyKeyError::new_err(format!( + "Table with name {} doesn't exist.", + name + ))), + } + } + + // register_table + // deregister_table +} + +#[pymethods] +impl PyTable { + /// Get a reference to the schema for this table + #[getter] + fn schema(&self, py: Python) -> PyResult<PyObject> { + self.table.schema().to_pyarrow(py) + } + + /// Get the type of this table for metadata/catalog purposes.
+ #[getter] + fn kind(&self) -> &str { + match self.table.table_type() { + TableType::Base => "physical", + TableType::View => "view", + TableType::Temporary => "temporary", + } + } + + // fn scan + // fn statistics + // fn has_exact_statistics + // fn supports_filter_pushdown +}
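`PyCatalog`, `PyDatabase`, and `PyTable` above are thin wrappers over DataFusion's existing catalog traits, and `test_catalog.py` earlier in this diff exercises them end to end. The chain of Rust calls they wrap looks roughly like this, not part of the patch (`"csv"` assumes a previously registered table; `"datafusion"`/`"public"` are the defaults used by the `#[args]` attributes above):

use datafusion::execution::context::ExecutionContext;

fn lookup(ctx: &ExecutionContext) -> Option<()> {
    let catalog = ctx.catalog("datafusion")?;  // Arc<dyn CatalogProvider>
    let database = catalog.schema("public")?;  // Arc<dyn SchemaProvider>
    let table = database.table("csv")?;        // Arc<dyn TableProvider>
    println!("{:?}", table.schema());
    Some(())
}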
diff --git a/python/src/context.rs b/python/src/context.rs index b813f27a73c9..7f386bac398d 100644 --- a/python/src/context.rs +++ b/python/src/context.rs @@ -20,73 +20,52 @@ use std::{collections::HashSet, sync::Arc}; use uuid::Uuid; -use tokio::runtime::Runtime; - -use pyo3::exceptions::PyValueError; +use pyo3::exceptions::{PyKeyError, PyValueError}; use pyo3::prelude::*; +use datafusion::arrow::datatypes::Schema; use datafusion::arrow::record_batch::RecordBatch; use datafusion::datasource::MemTable; -use datafusion::execution::context::ExecutionContext as _ExecutionContext; +use datafusion::execution::context::ExecutionContext; use datafusion::prelude::CsvReadOptions; -use crate::dataframe; -use crate::errors; -use crate::functions::{self, PyVolatility}; -use crate::to_rust; -use crate::types::PyDataType; +use crate::catalog::PyCatalog; +use crate::dataframe::PyDataFrame; +use crate::errors::DataFusionError; +use crate::udf::PyScalarUDF; +use crate::utils::wait_for_future; -/// `ExecutionContext` is able to plan and execute DataFusion plans. +/// `PyExecutionContext` is able to plan and execute DataFusion plans. /// It has a powerful optimizer, a physical planner for local execution, and a /// multi-threaded execution engine to perform the execution. -#[pyclass(unsendable)] -pub(crate) struct ExecutionContext { - ctx: _ExecutionContext, +#[pyclass(name = "ExecutionContext", module = "datafusion", subclass, unsendable)] +pub(crate) struct PyExecutionContext { + ctx: ExecutionContext, } #[pymethods] -impl ExecutionContext { +impl PyExecutionContext { + // TODO(kszucs): should expose the configuration options as keyword arguments #[new] fn new() -> Self { - ExecutionContext { - ctx: _ExecutionContext::new(), + PyExecutionContext { + ctx: ExecutionContext::new(), } } - /// Returns a DataFrame whose plan corresponds to the SQL statement. - fn sql(&mut self, query: &str, py: Python) -> PyResult<dataframe::DataFrame> { - let rt = Runtime::new().unwrap(); - let df = py.allow_threads(|| { - rt.block_on(async { - self.ctx - .sql(query) - .await - .map_err(|e| -> errors::DataFusionError { e.into() }) - }) - })?; - Ok(dataframe::DataFrame::new( - self.ctx.state.clone(), - df.to_logical_plan(), - )) + /// Returns a PyDataFrame whose plan corresponds to the SQL statement. + fn sql(&mut self, query: &str, py: Python) -> PyResult<PyDataFrame> { + let result = self.ctx.sql(query); + let df = wait_for_future(py, result).map_err(DataFusionError::from)?; + Ok(PyDataFrame::new(df)) } fn create_dataframe( &mut self, - partitions: Vec<Vec<PyObject>>, - py: Python, - ) -> PyResult<dataframe::DataFrame> { - let partitions: Vec<Vec<RecordBatch>> = partitions - .iter() - .map(|batches| { - batches - .iter() - .map(|batch| to_rust::to_rust_batch(batch.as_ref(py))) - .collect() - }) - .collect::<PyResult<_>>()?; - - let table = - errors::wrap(MemTable::try_new(partitions[0][0].schema(), partitions))?; + partitions: Vec<Vec<RecordBatch>>, + ) -> PyResult<PyDataFrame> { + let table = MemTable::try_new(partitions[0][0].schema(), partitions) + .map_err(DataFusionError::from)?; // generate a random (unique) name for this table // table name cannot start with numeric digit @@ -95,43 +74,31 @@ .to_simple() .encode_lower(&mut Uuid::encode_buffer()); - errors::wrap(self.ctx.register_table(&*name, Arc::new(table)))?; - Ok(dataframe::DataFrame::new( - self.ctx.state.clone(), - errors::wrap(self.ctx.table(&*name))?.to_logical_plan(), - )) + self.ctx + .register_table(&*name, Arc::new(table)) + .map_err(DataFusionError::from)?; + let table = self.ctx.table(&*name).map_err(DataFusionError::from)?; + + let df = PyDataFrame::new(table); + Ok(df) } fn register_record_batches( &mut self, name: &str, - partitions: Vec<Vec<PyObject>>, - py: Python, + partitions: Vec<Vec<RecordBatch>>, ) -> PyResult<()> { - let partitions: Vec<Vec<RecordBatch>> = partitions - .iter() - .map(|batches| { - batches - .iter() - .map(|batch| to_rust::to_rust_batch(batch.as_ref(py))) - .collect() - }) - .collect::<PyResult<_>>()?; - - let table = - errors::wrap(MemTable::try_new(partitions[0][0].schema(), partitions))?; - - errors::wrap(self.ctx.register_table(&*name, Arc::new(table)))?; + let schema = partitions[0][0].schema(); + let table = MemTable::try_new(schema, partitions)?; + self.ctx + .register_table(name, Arc::new(table)) + .map_err(DataFusionError::from)?; Ok(()) } fn register_parquet(&mut self, name: &str, path: &str, py: Python) -> PyResult<()> { - let rt = Runtime::new().unwrap(); - py.allow_threads(|| { - rt.block_on(async { - errors::wrap(self.ctx.register_parquet(name, path).await) - }) - })?; + let result = self.ctx.register_parquet(name, path); + wait_for_future(py, result).map_err(DataFusionError::from)?; Ok(()) } @@ -146,7 +113,7 @@ &mut self, name: &str, path: PathBuf, - schema: Option<&PyAny>, + schema: Option<Schema>, has_header: bool, delimiter: &str, schema_infer_max_records: usize, @@ -156,10 +123,6 @@ let path = path .to_str() .ok_or(PyValueError::new_err("Unable to convert path to a string"))?; - let schema = match schema { - Some(s) => Some(to_rust::to_rust_schema(s)?), - None => None, - }; let delimiter = delimiter.as_bytes(); if delimiter.len() != 1 { return Err(PyValueError::new_err( @@ -174,30 +137,37 @@ .file_extension(file_extension); options.schema = schema.as_ref(); - let rt = Runtime::new().unwrap(); - py.allow_threads(|| { - rt.block_on(async { - errors::wrap(self.ctx.register_csv(name, path, options).await) - }) - })?; + let result = self.ctx.register_csv(name, path, options); + wait_for_future(py, result).map_err(DataFusionError::from)?; + Ok(()) } - fn register_udf( - &mut self, - name: &str, - func: PyObject, - args_types: Vec<PyDataType>, - return_type: PyDataType, - volatility: PyVolatility, - ) { - let function = - functions::create_udf(func, args_types, return_type, volatility, name); - - self.ctx.register_udf(function.function); + fn register_udf(&mut
self, udf: PyScalarUDF) -> PyResult<()> { + self.ctx.register_udf(udf.function); + Ok(()) + } + + #[args(name = "\"datafusion\"")] + fn catalog(&self, name: &str) -> PyResult<PyCatalog> { + match self.ctx.catalog(name) { + Some(catalog) => Ok(PyCatalog::new(catalog)), + None => Err(PyKeyError::new_err(format!( + "Catalog with name {} doesn't exist.", + &name + ))), + } } fn tables(&self) -> HashSet<String> { self.ctx.tables().unwrap() } + + fn table(&self, name: &str) -> PyResult<PyDataFrame> { + Ok(PyDataFrame::new(self.ctx.table(name)?)) + } + + fn empty_table(&self) -> PyResult<PyDataFrame> { + Ok(PyDataFrame::new(self.ctx.read_empty()?)) + } }
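The reworked context above drops the per-call `tokio::runtime::Runtime` in favour of the shared `wait_for_future` helper (imported from `crate::utils`, whose definition falls outside this section) and grows `catalog`, `table`, and `empty_table` methods. The new `table` and `empty_table` delegate to plain `ExecutionContext` calls, sketched here, not part of the patch (`"csv"` assumes a previously registered table):

use datafusion::error::Result;
use datafusion::execution::context::ExecutionContext;

fn dataframes(ctx: &ExecutionContext) -> Result<()> {
    let _csv = ctx.table("csv")?;  // what PyExecutionContext::table wraps
    let empty = ctx.read_empty()?; // added near the top of this diff
    assert_eq!(empty.schema().fields().len(), 0);
    Ok(())
}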
- #[args(args = "*")] - fn select(&self, args: &PyTuple) -> PyResult { - let expressions = expression::from_tuple(args)?; - let builder = LogicalPlanBuilder::from(self.plan.clone()); - let builder = - errors::wrap(builder.project(expressions.into_iter().map(|e| e.expr)))?; - let plan = errors::wrap(builder.build())?; - - Ok(DataFrame { - ctx_state: self.ctx_state.clone(), - plan, - }) +impl PyDataFrame { + /// Returns the schema from the logical plan + fn schema(&self) -> Schema { + self.df.schema().into() } - /// Filter according to the `predicate` expression - fn filter(&self, predicate: expression::Expression) -> PyResult { - let builder = LogicalPlanBuilder::from(self.plan.clone()); - let builder = errors::wrap(builder.filter(predicate.expr))?; - let plan = errors::wrap(builder.build())?; + #[args(args = "*")] + fn select(&self, args: Vec) -> PyResult { + let expr = args.into_iter().map(|e| e.into()).collect(); + let df = self.df.select(expr)?; + Ok(Self::new(df)) + } - Ok(DataFrame { - ctx_state: self.ctx_state.clone(), - plan, - }) + fn filter(&self, predicate: PyExpr) -> PyResult { + let df = self.df.filter(predicate.into())?; + Ok(Self::new(df)) } - /// Aggregates using expressions - fn aggregate( - &self, - group_by: Vec, - aggs: Vec, - ) -> PyResult { - let builder = LogicalPlanBuilder::from(self.plan.clone()); - let builder = errors::wrap(builder.aggregate( - group_by.into_iter().map(|e| e.expr), - aggs.into_iter().map(|e| e.expr), - ))?; - let plan = errors::wrap(builder.build())?; - - Ok(DataFrame { - ctx_state: self.ctx_state.clone(), - plan, - }) + fn aggregate(&self, group_by: Vec, aggs: Vec) -> PyResult { + let group_by = group_by.into_iter().map(|e| e.into()).collect(); + let aggs = aggs.into_iter().map(|e| e.into()).collect(); + let df = self.df.aggregate(group_by, aggs)?; + Ok(Self::new(df)) } - /// Sort by specified sorting expressions - fn sort(&self, exprs: Vec) -> PyResult { - let exprs = exprs.into_iter().map(|e| e.expr); - let builder = LogicalPlanBuilder::from(self.plan.clone()); - let builder = errors::wrap(builder.sort(exprs))?; - let plan = errors::wrap(builder.build())?; - Ok(DataFrame { - ctx_state: self.ctx_state.clone(), - plan, - }) + #[args(exprs = "*")] + fn sort(&self, exprs: Vec) -> PyResult { + let exprs = exprs.into_iter().map(|e| e.into()).collect(); + let df = self.df.sort(exprs)?; + Ok(Self::new(df)) } - /// Limits the plan to return at most `count` rows fn limit(&self, count: usize) -> PyResult { - let builder = LogicalPlanBuilder::from(self.plan.clone()); - let builder = errors::wrap(builder.limit(count))?; - let plan = errors::wrap(builder.build())?; - - Ok(DataFrame { - ctx_state: self.ctx_state.clone(), - plan, - }) + let df = self.df.limit(count)?; + Ok(Self::new(df)) } /// Executes the plan, returning a list of `RecordBatch`es. 
-    /// Unless some order is specified in the plan, there is no guarantee of the order of the result
-    fn collect(&self, py: Python) -> PyResult<PyObject> {
-        let ctx = _ExecutionContext::from(self.ctx_state.clone());
-        let rt = Runtime::new().unwrap();
-        let plan = ctx
-            .optimize(&self.plan)
-            .map_err(|e| -> errors::DataFusionError { e.into() })?;
-
-        let plan = py.allow_threads(|| {
-            rt.block_on(async {
-                ctx.create_physical_plan(&plan)
-                    .await
-                    .map_err(|e| -> errors::DataFusionError { e.into() })
-            })
-        })?;
-
-        let batches = py.allow_threads(|| {
-            rt.block_on(async {
-                collect(plan)
-                    .await
-                    .map_err(|e| -> errors::DataFusionError { e.into() })
-            })
-        })?;
-        to_py::to_py(&batches)
+    /// Unless some order is specified in the plan, there is no
+    /// guarantee of the order of the result.
+    fn collect(&self, py: Python) -> PyResult<Vec<PyObject>> {
+        let batches = wait_for_future(py, self.df.collect())?;
+        // cannot use PyResult<Vec<RecordBatch>> return type due to
+        // https://github.com/PyO3/pyo3/issues/1813
+        batches.into_iter().map(|rb| rb.to_pyarrow(py)).collect()
     }
 
     /// Print the result, 20 lines by default
     #[args(num = "20")]
     fn show(&self, py: Python, num: usize) -> PyResult<()> {
-        let ctx = _ExecutionContext::from(self.ctx_state.clone());
-        let rt = Runtime::new().unwrap();
-        let plan = py.allow_threads(|| {
-            rt.block_on(async {
-                let l_plan = ctx
-                    .optimize(&self.limit(num)?.plan)
-                    .map_err(|e| -> errors::DataFusionError { e.into() })?;
-                let p_plan = ctx
-                    .create_physical_plan(&l_plan)
-                    .await
-                    .map_err(|e| -> errors::DataFusionError { e.into() })?;
-                Ok::<_, PyErr>(p_plan)
-            })
-        })?;
-
-        let batches = py.allow_threads(|| {
-            rt.block_on(async {
-                collect(plan)
-                    .await
-                    .map_err(|e| -> errors::DataFusionError { e.into() })
-            })
-        })?;
-
-        Ok(pretty::print_batches(&batches).unwrap())
+        let df = self.df.limit(num)?;
+        let batches = wait_for_future(py, df.collect())?;
+        Ok(pretty::print_batches(&batches)?)
     }
 
-    /// Returns the join of two DataFrames `on`.
fn join( &self, - right: &DataFrame, + right: PyDataFrame, join_keys: (Vec<&str>, Vec<&str>), how: &str, ) -> PyResult { - let builder = LogicalPlanBuilder::from(self.plan.clone()); - let join_type = match how { "inner" => JoinType::Inner, "left" => JoinType::Left, @@ -199,13 +122,9 @@ impl DataFrame { } }; - let builder = errors::wrap(builder.join(&right.plan, join_type, join_keys))?; - - let plan = errors::wrap(builder.build())?; - - Ok(DataFrame { - ctx_state: self.ctx_state.clone(), - plan, - }) + let df = self + .df + .join(right.df, join_type, &join_keys.0, &join_keys.1)?; + Ok(Self::new(df)) } } diff --git a/python/src/errors.rs b/python/src/errors.rs index fbe98037a030..655ed8441cb4 100644 --- a/python/src/errors.rs +++ b/python/src/errors.rs @@ -19,7 +19,7 @@ use core::fmt; use datafusion::arrow::error::ArrowError; use datafusion::error::DataFusionError as InnerDataFusionError; -use pyo3::{exceptions, PyErr}; +use pyo3::{exceptions::PyException, PyErr}; #[derive(Debug)] pub enum DataFusionError { @@ -38,9 +38,9 @@ impl fmt::Display for DataFusionError { } } -impl From for PyErr { - fn from(err: DataFusionError) -> PyErr { - exceptions::PyException::new_err(err.to_string()) +impl From for DataFusionError { + fn from(err: ArrowError) -> DataFusionError { + DataFusionError::ArrowError(err) } } @@ -50,12 +50,8 @@ impl From for DataFusionError { } } -impl From for DataFusionError { - fn from(err: ArrowError) -> DataFusionError { - DataFusionError::ArrowError(err) +impl From for PyErr { + fn from(err: DataFusionError) -> PyErr { + PyException::new_err(err.to_string()) } } - -pub(crate) fn wrap(a: Result) -> Result { - Ok(a?) -} diff --git a/python/src/expression.rs b/python/src/expression.rs index 4320b1d14c8b..21cecaa1ccce 100644 --- a/python/src/expression.rs +++ b/python/src/expression.rs @@ -15,156 +15,117 @@ // specific language governing permissions and limitations // under the License. 
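The expression module rewritten below replaces the old `Expression`/`from_tuple` pair with `PyExpr` plus `From`/`Into` conversions, so from Python the operator overloads compose exactly as before. A minimal sketch of the resulting API; it assumes the pure-python layer (not part of this diff) re-exports `col` and `lit` on top of `Expression.column`/`Expression.literal`:

```python
import pyarrow as pa
from datafusion import ExecutionContext, functions as f  # assumed re-exports

ctx = ExecutionContext()
batch = pa.RecordBatch.from_arrays(
    [pa.array(["Hello", "World", "!"]), pa.array([4, 5, 6])], names=["a", "b"]
)
df = ctx.create_dataframe([[batch]])
# +, -, *, /, %, &, |, ~ and comparisons map onto the
# PyNumberProtocol/PyObjectProtocol impls below
df = df.filter(f.col("b") > f.lit(4)).select(f.col("a"), f.col("b") + f.lit(1))
```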
-use pyo3::{
-    basic::CompareOp, prelude::*, types::PyTuple, PyNumberProtocol, PyObjectProtocol,
-};
+use pyo3::{basic::CompareOp, prelude::*, PyNumberProtocol, PyObjectProtocol};
+use std::convert::{From, Into};
 
-use datafusion::logical_plan::Expr as _Expr;
-use datafusion::physical_plan::udaf::AggregateUDF as _AggregateUDF;
-use datafusion::physical_plan::udf::ScalarUDF as _ScalarUDF;
+use datafusion::arrow::datatypes::DataType;
+use datafusion::logical_plan::{col, lit, Expr};
 
-/// An expression that can be used on a DataFrame
-#[pyclass]
+use datafusion::scalar::ScalarValue;
+
+/// A PyExpr that can be used on a DataFrame
+#[pyclass(name = "Expression", module = "datafusion", subclass)]
 #[derive(Debug, Clone)]
-pub(crate) struct Expression {
-    pub(crate) expr: _Expr,
+pub(crate) struct PyExpr {
+    pub(crate) expr: Expr,
 }
 
-/// converts a tuple of expressions into a vector of Expressions
-pub(crate) fn from_tuple(value: &PyTuple) -> PyResult<Vec<Expression>> {
-    value
-        .iter()
-        .map(|e| e.extract::<Expression>())
-        .collect::<PyResult<Vec<Expression>>>()
+impl From<PyExpr> for Expr {
+    fn from(expr: PyExpr) -> Expr {
+        expr.expr
+    }
+}
+
+impl Into<PyExpr> for Expr {
+    fn into(self) -> PyExpr {
+        PyExpr { expr: self }
+    }
 }
 
 #[pyproto]
-impl PyNumberProtocol for Expression {
-    fn __add__(lhs: Expression, rhs: Expression) -> PyResult<Expression> {
-        Ok(Expression {
-            expr: lhs.expr + rhs.expr,
-        })
+impl PyNumberProtocol for PyExpr {
+    fn __add__(lhs: PyExpr, rhs: PyExpr) -> PyResult<PyExpr> {
+        Ok((lhs.expr + rhs.expr).into())
     }
 
-    fn __sub__(lhs: Expression, rhs: Expression) -> PyResult<Expression> {
-        Ok(Expression {
-            expr: lhs.expr - rhs.expr,
-        })
+    fn __sub__(lhs: PyExpr, rhs: PyExpr) -> PyResult<PyExpr> {
+        Ok((lhs.expr - rhs.expr).into())
     }
 
-    fn __truediv__(lhs: Expression, rhs: Expression) -> PyResult<Expression> {
-        Ok(Expression {
-            expr: lhs.expr / rhs.expr,
-        })
+    fn __truediv__(lhs: PyExpr, rhs: PyExpr) -> PyResult<PyExpr> {
+        Ok((lhs.expr / rhs.expr).into())
     }
 
-    fn __mul__(lhs: Expression, rhs: Expression) -> PyResult<Expression> {
-        Ok(Expression {
-            expr: lhs.expr * rhs.expr,
-        })
+    fn __mul__(lhs: PyExpr, rhs: PyExpr) -> PyResult<PyExpr> {
+        Ok((lhs.expr * rhs.expr).into())
    }
 
-    fn __and__(lhs: Expression, rhs: Expression) -> PyResult<Expression> {
-        Ok(Expression {
-            expr: lhs.expr.and(rhs.expr),
-        })
+    fn __mod__(lhs: PyExpr, rhs: PyExpr) -> PyResult<PyExpr> {
+        Ok(lhs.expr.clone().modulus(rhs.expr).into())
     }
 
-    fn __or__(lhs: Expression, rhs: Expression) -> PyResult<Expression> {
-        Ok(Expression {
-            expr: lhs.expr.or(rhs.expr),
-        })
+    fn __and__(lhs: PyExpr, rhs: PyExpr) -> PyResult<PyExpr> {
+        Ok(lhs.expr.clone().and(rhs.expr).into())
     }
 
-    fn __invert__(&self) -> PyResult<Expression> {
-        Ok(Expression {
-            expr: self.expr.clone().not(),
-        })
+    fn __or__(lhs: PyExpr, rhs: PyExpr) -> PyResult<PyExpr> {
+        Ok(lhs.expr.clone().or(rhs.expr).into())
+    }
+
+    fn __invert__(&self) -> PyResult<PyExpr> {
+        Ok(self.expr.clone().not().into())
     }
 }
 
 #[pyproto]
-impl PyObjectProtocol for Expression {
-    fn __richcmp__(&self, other: Expression, op: CompareOp) -> Expression {
-        match op {
-            CompareOp::Lt => Expression {
-                expr: self.expr.clone().lt(other.expr),
-            },
-            CompareOp::Le => Expression {
-                expr: self.expr.clone().lt_eq(other.expr),
-            },
-            CompareOp::Eq => Expression {
-                expr: self.expr.clone().eq(other.expr),
-            },
-            CompareOp::Ne => Expression {
-                expr: self.expr.clone().not_eq(other.expr),
-            },
-            CompareOp::Gt => Expression {
-                expr: self.expr.clone().gt(other.expr),
-            },
-            CompareOp::Ge => Expression {
-                expr: self.expr.clone().gt_eq(other.expr),
-            },
-        }
+impl PyObjectProtocol for PyExpr {
+    fn __richcmp__(&self, other: PyExpr, op: CompareOp) -> PyExpr {
+        let expr = match op {
CompareOp::Lt => self.expr.clone().lt(other.expr), + CompareOp::Le => self.expr.clone().lt_eq(other.expr), + CompareOp::Eq => self.expr.clone().eq(other.expr), + CompareOp::Ne => self.expr.clone().not_eq(other.expr), + CompareOp::Gt => self.expr.clone().gt(other.expr), + CompareOp::Ge => self.expr.clone().gt_eq(other.expr), + }; + expr.into() } } #[pymethods] -impl Expression { - /// assign a name to the expression - pub fn alias(&self, name: &str) -> PyResult { - Ok(Expression { - expr: self.expr.clone().alias(name), - }) +impl PyExpr { + #[staticmethod] + pub fn literal(value: ScalarValue) -> PyExpr { + lit(value).into() } - /// Create a sort expression from an existing expression. - #[args(ascending = true, nulls_first = true)] - pub fn sort(&self, ascending: bool, nulls_first: bool) -> PyResult { - Ok(Expression { - expr: self.expr.clone().sort(ascending, nulls_first), - }) + #[staticmethod] + pub fn column(value: &str) -> PyExpr { + col(value).into() } -} - -/// Represents a ScalarUDF -#[pyclass] -#[derive(Debug, Clone)] -pub struct ScalarUDF { - pub(crate) function: _ScalarUDF, -} -#[pymethods] -impl ScalarUDF { - /// creates a new expression with the call of the udf - #[call] - #[args(args = "*")] - fn __call__(&self, args: &PyTuple) -> PyResult { - let args = from_tuple(args)?.iter().map(|e| e.expr.clone()).collect(); + /// assign a name to the PyExpr + pub fn alias(&self, name: &str) -> PyExpr { + self.expr.clone().alias(name).into() + } - Ok(Expression { - expr: self.function.call(args), - }) + /// Create a sort PyExpr from an existing PyExpr. + #[args(ascending = true, nulls_first = true)] + pub fn sort(&self, ascending: bool, nulls_first: bool) -> PyExpr { + self.expr.clone().sort(ascending, nulls_first).into() } -} -/// Represents a AggregateUDF -#[pyclass] -#[derive(Debug, Clone)] -pub struct AggregateUDF { - pub(crate) function: _AggregateUDF, -} + pub fn is_null(&self) -> PyExpr { + self.expr.clone().is_null().into() + } -#[pymethods] -impl AggregateUDF { - /// creates a new expression with the call of the udf - #[call] - #[args(args = "*")] - fn __call__(&self, args: &PyTuple) -> PyResult { - let args = from_tuple(args)?.iter().map(|e| e.expr.clone()).collect(); - - Ok(Expression { - expr: self.function.call(args), - }) + pub fn cast(&self, to: DataType) -> PyExpr { + // self.expr.cast_to() requires DFSchema to validate that the cast + // is supported, omit that for now + let expr = Expr::Cast { + expr: Box::new(self.expr.clone()), + data_type: to, + }; + expr.into() } } diff --git a/python/src/functions.rs b/python/src/functions.rs index e7d141d61171..a2862202602f 100644 --- a/python/src/functions.rs +++ b/python/src/functions.rs @@ -15,87 +15,37 @@ // specific language governing permissions and limitations // under the License. -use crate::udaf; -use crate::udf; -use crate::{expression, types::PyDataType}; -use datafusion::arrow::datatypes::DataType; -use datafusion::logical_plan::{self, Literal}; -use datafusion::physical_plan::functions::Volatility; -use pyo3::{prelude::*, types::PyTuple, wrap_pyfunction, Python}; -use std::sync::Arc; +use pyo3::{prelude::*, wrap_pyfunction}; -/// Expression representing a column on the existing plan. 
-#[pyfunction] -#[pyo3(text_signature = "(name)")] -fn col(name: &str) -> expression::Expression { - expression::Expression { - expr: logical_plan::col(name), - } -} +use datafusion::logical_plan; -/// # A bridge type that converts PyAny data into datafusion literal -/// -/// Note that the ordering here matters because it has to be from -/// narrow to wider values because Python has duck typing so putting -/// Int before Boolean results in a premature match. -#[derive(FromPyObject)] -enum PythonLiteral<'a> { - Boolean(bool), - Int(i64), - UInt(u64), - Float(f64), - Str(&'a str), - Binary(&'a [u8]), -} - -impl<'a> Literal for PythonLiteral<'a> { - fn lit(&self) -> logical_plan::Expr { - match self { - PythonLiteral::Boolean(val) => val.lit(), - PythonLiteral::Int(val) => val.lit(), - PythonLiteral::UInt(val) => val.lit(), - PythonLiteral::Float(val) => val.lit(), - PythonLiteral::Str(val) => val.lit(), - PythonLiteral::Binary(val) => val.lit(), - } - } -} +use datafusion::physical_plan::{ + aggregates::AggregateFunction, functions::BuiltinScalarFunction, +}; -/// Expression representing a constant value -#[pyfunction] -#[pyo3(text_signature = "(value)")] -fn lit(value: &PyAny) -> PyResult { - let py_lit = value.extract::()?; - let expr = py_lit.lit(); - Ok(expression::Expression { expr }) -} +use crate::expression::PyExpr; #[pyfunction] -fn array(value: Vec) -> expression::Expression { - expression::Expression { +fn array(value: Vec) -> PyExpr { + PyExpr { expr: logical_plan::array(value.into_iter().map(|x| x.expr).collect::>()), } } #[pyfunction] -fn in_list( - expr: expression::Expression, - value: Vec, - negated: bool, -) -> expression::Expression { - expression::Expression { - expr: logical_plan::in_list( - expr.expr, - value.into_iter().map(|x| x.expr).collect::>(), - negated, - ), - } +fn in_list(expr: PyExpr, value: Vec, negated: bool) -> PyExpr { + logical_plan::in_list( + expr.expr, + value.into_iter().map(|x| x.expr).collect::>(), + negated, + ) + .into() } /// Current date and time #[pyfunction] -fn now() -> expression::Expression { - expression::Expression { +fn now() -> PyExpr { + PyExpr { // here lit(0) is a stub for conform to arity expr: logical_plan::now(logical_plan::lit(0)), } @@ -103,8 +53,8 @@ fn now() -> expression::Expression { /// Returns a random value in the range 0.0 <= x < 1.0 #[pyfunction] -fn random() -> expression::Expression { - expression::Expression { +fn random() -> PyExpr { + PyExpr { expr: logical_plan::random(), } } @@ -112,11 +62,8 @@ fn random() -> expression::Expression { /// Computes a binary hash of the given data. type is the algorithm to use. /// Standard algorithms are md5, sha224, sha256, sha384, sha512, blake2s, blake2b, and blake3. #[pyfunction(value, method)] -fn digest( - value: expression::Expression, - method: expression::Expression, -) -> expression::Expression { - expression::Expression { +fn digest(value: PyExpr, method: PyExpr) -> PyExpr { + PyExpr { expr: logical_plan::digest(value.expr, method.expr), } } @@ -124,285 +71,212 @@ fn digest( /// Concatenates the text representations of all the arguments. /// NULL arguments are ignored. 
#[pyfunction(args = "*")] -fn concat(args: &PyTuple) -> PyResult { - let expressions = expression::from_tuple(args)?; - let args = expressions.into_iter().map(|e| e.expr).collect::>(); - Ok(expression::Expression { - expr: logical_plan::concat(&args), - }) +fn concat(args: Vec) -> PyResult { + let args = args.into_iter().map(|e| e.expr).collect::>(); + Ok(logical_plan::concat(&args).into()) } /// Concatenates all but the first argument, with separators. /// The first argument is used as the separator string, and should not be NULL. /// Other NULL arguments are ignored. #[pyfunction(sep, args = "*")] -fn concat_ws(sep: String, args: &PyTuple) -> PyResult { - let expressions = expression::from_tuple(args)?; - let args = expressions.into_iter().map(|e| e.expr).collect::>(); - Ok(expression::Expression { - expr: logical_plan::concat_ws(sep, &args), - }) +fn concat_ws(sep: String, args: Vec) -> PyResult { + let args = args.into_iter().map(|e| e.expr).collect::>(); + Ok(logical_plan::concat_ws(sep, &args).into()) } -macro_rules! define_unary_function { - ($NAME: ident) => { - #[doc = "This function is not documented yet"] - #[pyfunction] - fn $NAME(value: expression::Expression) -> expression::Expression { - expression::Expression { - expr: logical_plan::$NAME(value.expr), - } - } +macro_rules! scalar_function { + ($NAME: ident, $FUNC: ident) => { + scalar_function!($NAME, $FUNC, stringify!($NAME)); }; - ($NAME: ident, $DOC: expr) => { + ($NAME: ident, $FUNC: ident, $DOC: expr) => { #[doc = $DOC] - #[pyfunction] - fn $NAME(value: expression::Expression) -> expression::Expression { - expression::Expression { - expr: logical_plan::$NAME(value.expr), - } + #[pyfunction(args = "*")] + fn $NAME(args: Vec) -> PyExpr { + let expr = logical_plan::Expr::ScalarFunction { + fun: BuiltinScalarFunction::$FUNC, + args: args.into_iter().map(|e| e.into()).collect(), + }; + expr.into() } }; } -define_unary_function!(sqrt, "sqrt"); -define_unary_function!(sin, "sin"); -define_unary_function!(cos, "cos"); -define_unary_function!(tan, "tan"); -define_unary_function!(asin, "asin"); -define_unary_function!(acos, "acos"); -define_unary_function!(atan, "atan"); -define_unary_function!(floor, "floor"); -define_unary_function!(ceil, "ceil"); -define_unary_function!(round, "round"); -define_unary_function!(trunc, "trunc"); -define_unary_function!(abs, "abs"); -define_unary_function!(signum, "signum"); -define_unary_function!(exp, "exp"); -define_unary_function!(ln, "ln"); -define_unary_function!(log2, "log2"); -define_unary_function!(log10, "log10"); +macro_rules! aggregate_function { + ($NAME: ident, $FUNC: ident) => { + aggregate_function!($NAME, $FUNC, stringify!($NAME)); + }; + ($NAME: ident, $FUNC: ident, $DOC: expr) => { + #[doc = $DOC] + #[pyfunction(args = "*", distinct = "false")] + fn $NAME(args: Vec, distinct: bool) -> PyExpr { + let expr = logical_plan::Expr::AggregateFunction { + fun: AggregateFunction::$FUNC, + args: args.into_iter().map(|e| e.into()).collect(), + distinct, + }; + expr.into() + } + }; +} -define_unary_function!(ascii, "Returns the numeric code of the first character of the argument. In UTF8 encoding, returns the Unicode code point of the character. In other multibyte encodings, the argument must be an ASCII character."); -define_unary_function!(sum); -define_unary_function!( +scalar_function!(abs, Abs); +scalar_function!(acos, Acos); +scalar_function!(ascii, Ascii, "Returns the numeric code of the first character of the argument. 
In UTF8 encoding, returns the Unicode code point of the character. In other multibyte encodings, the argument must be an ASCII character."); +scalar_function!(asin, Asin); +scalar_function!(atan, Atan); +scalar_function!( bit_length, + BitLength, "Returns number of bits in the string (8 times the octet_length)." ); -define_unary_function!(btrim, "Removes the longest string containing only characters in characters (a space by default) from the start and end of string."); -define_unary_function!( +scalar_function!(btrim, Btrim, "Removes the longest string containing only characters in characters (a space by default) from the start and end of string."); +scalar_function!(ceil, Ceil); +scalar_function!( character_length, + CharacterLength, "Returns number of characters in the string." ); -define_unary_function!(chr, "Returns the character with the given code."); -define_unary_function!(initcap, "Converts the first letter of each word to upper case and the rest to lower case. Words are sequences of alphanumeric characters separated by non-alphanumeric characters."); -define_unary_function!(left, "Returns first n characters in the string, or when n is negative, returns all but last |n| characters."); -define_unary_function!(lower, "Converts the string to all lower case"); -define_unary_function!(lpad, "Extends the string to length length by prepending the characters fill (a space by default). If the string is already longer than length then it is truncated (on the right)."); -define_unary_function!(ltrim, "Removes the longest string containing only characters in characters (a space by default) from the start of string."); -define_unary_function!( +scalar_function!(chr, Chr, "Returns the character with the given code."); +scalar_function!(cos, Cos); +scalar_function!(exp, Exp); +scalar_function!(floor, Floor); +scalar_function!(initcap, InitCap, "Converts the first letter of each word to upper case and the rest to lower case. Words are sequences of alphanumeric characters separated by non-alphanumeric characters."); +scalar_function!(left, Left, "Returns first n characters in the string, or when n is negative, returns all but last |n| characters."); +scalar_function!(ln, Ln); +scalar_function!(log10, Log10); +scalar_function!(log2, Log2); +scalar_function!(lower, Lower, "Converts the string to all lower case"); +scalar_function!(lpad, Lpad, "Extends the string to length length by prepending the characters fill (a space by default). If the string is already longer than length then it is truncated (on the right)."); +scalar_function!(ltrim, Ltrim, "Removes the longest string containing only characters in characters (a space by default) from the start of string."); +scalar_function!( md5, + MD5, "Computes the MD5 hash of the argument, with the result written in hexadecimal." ); -define_unary_function!(octet_length, "Returns number of bytes in the string. Since this version of the function accepts type character directly, it will not strip trailing spaces."); -define_unary_function!( - replace, - "Replaces all occurrences in string of substring from with substring to." -); -define_unary_function!(repeat, "Repeats string the specified number of times."); -define_unary_function!( +scalar_function!(octet_length, OctetLength, "Returns number of bytes in the string. 
Since this version of the function accepts type character directly, it will not strip trailing spaces."); +scalar_function!(regexp_match, RegexpMatch); +scalar_function!( regexp_replace, + RegexpReplace, "Replaces substring(s) matching a POSIX regular expression" ); -define_unary_function!( +scalar_function!( + repeat, + Repeat, + "Repeats string the specified number of times." +); +scalar_function!( + replace, + Replace, + "Replaces all occurrences in string of substring from with substring to." +); +scalar_function!( reverse, + Reverse, "Reverses the order of the characters in the string." ); -define_unary_function!(right, "Returns last n characters in the string, or when n is negative, returns all but first |n| characters."); -define_unary_function!(rpad, "Extends the string to length length by appending the characters fill (a space by default). If the string is already longer than length then it is truncated."); -define_unary_function!(rtrim, "Removes the longest string containing only characters in characters (a space by default) from the end of string."); -define_unary_function!(sha224); -define_unary_function!(sha256); -define_unary_function!(sha384); -define_unary_function!(sha512); -define_unary_function!(split_part, "Splits string at occurrences of delimiter and returns the n'th field (counting from one)."); -define_unary_function!(starts_with, "Returns true if string starts with prefix."); -define_unary_function!(strpos,"Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.)"); -define_unary_function!(substr); -define_unary_function!( +scalar_function!(right, Right, "Returns last n characters in the string, or when n is negative, returns all but first |n| characters."); +scalar_function!(round, Round); +scalar_function!(rpad, Rpad, "Extends the string to length length by appending the characters fill (a space by default). If the string is already longer than length then it is truncated."); +scalar_function!(rtrim, Rtrim, "Removes the longest string containing only characters in characters (a space by default) from the end of string."); +scalar_function!(sha224, SHA224); +scalar_function!(sha256, SHA256); +scalar_function!(sha384, SHA384); +scalar_function!(sha512, SHA512); +scalar_function!(signum, Signum); +scalar_function!(sin, Sin); +scalar_function!(split_part, SplitPart, "Splits string at occurrences of delimiter and returns the n'th field (counting from one)."); +scalar_function!(sqrt, Sqrt); +scalar_function!( + starts_with, + StartsWith, + "Returns true if string starts with prefix." +); +scalar_function!(strpos, Strpos, "Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.)"); +scalar_function!(substr, Substr); +scalar_function!(tan, Tan); +scalar_function!( to_hex, + ToHex, "Converts the number to its equivalent hexadecimal representation." ); -define_unary_function!(translate, "Replaces each character in string that matches a character in the from set with the corresponding character in the to set. 
If from is longer than to, occurrences of the extra characters in from are deleted."); -define_unary_function!(trim, "Removes the longest string containing only characters in characters (a space by default) from the start, end, or both ends (BOTH is the default) of string."); -define_unary_function!(upper, "Converts the string to all upper case."); -define_unary_function!(avg); -define_unary_function!(min); -define_unary_function!(max); -define_unary_function!(count); -define_unary_function!(approx_distinct); - -#[pyclass(name = "Volatility", module = "datafusion.functions")] -#[derive(Clone)] -pub struct PyVolatility { - pub(crate) volatility: Volatility, -} - -#[pymethods] -impl PyVolatility { - #[staticmethod] - fn immutable() -> Self { - Self { - volatility: Volatility::Immutable, - } - } - #[staticmethod] - fn stable() -> Self { - Self { - volatility: Volatility::Stable, - } - } - #[staticmethod] - fn volatile() -> Self { - Self { - volatility: Volatility::Volatile, - } - } -} - -pub(crate) fn create_udf( - fun: PyObject, - input_types: Vec, - return_type: PyDataType, - volatility: PyVolatility, - name: &str, -) -> expression::ScalarUDF { - let input_types: Vec = - input_types.iter().map(|d| d.data_type.clone()).collect(); - let return_type = Arc::new(return_type.data_type); - - expression::ScalarUDF { - function: logical_plan::create_udf( - name, - input_types, - return_type, - volatility.volatility, - udf::array_udf(fun), - ), - } -} - -/// Creates a new UDF (User Defined Function). -#[pyfunction] -fn udf( - fun: PyObject, - input_types: Vec, - return_type: PyDataType, - volatility: PyVolatility, - py: Python, -) -> PyResult { - let name = fun.getattr(py, "__qualname__")?.extract::(py)?; - - Ok(create_udf(fun, input_types, return_type, volatility, &name)) -} - -/// Creates a new UDAF (User Defined Aggregate Function). -#[pyfunction] -fn udaf( - accumulator: PyObject, - input_type: PyDataType, - return_type: PyDataType, - state_type: Vec, - volatility: PyVolatility, - py: Python, -) -> PyResult { - let name = accumulator - .getattr(py, "__qualname__")? - .extract::(py)?; - - let input_type = input_type.data_type; - let return_type = Arc::new(return_type.data_type); - let state_type = Arc::new(state_type.into_iter().map(|t| t.data_type).collect()); - - Ok(expression::AggregateUDF { - function: logical_plan::create_udaf( - &name, - input_type, - return_type, - volatility.volatility, - udaf::array_udaf(accumulator), - state_type, - ), - }) -} +scalar_function!(to_timestamp, ToTimestamp); +scalar_function!(translate, Translate, "Replaces each character in string that matches a character in the from set with the corresponding character in the to set. 
If from is longer than to, occurrences of the extra characters in from are deleted."); +scalar_function!(trim, Trim, "Removes the longest string containing only characters in characters (a space by default) from the start, end, or both ends (BOTH is the default) of string."); +scalar_function!(trunc, Trunc); +scalar_function!(upper, Upper, "Converts the string to all upper case."); -pub fn init(module: &PyModule) -> PyResult<()> { - module.add_class::()?; - module.add_function(wrap_pyfunction!(abs, module)?)?; - module.add_function(wrap_pyfunction!(acos, module)?)?; - module.add_function(wrap_pyfunction!(approx_distinct, module)?)?; - module.add_function(wrap_pyfunction!(array, module)?)?; - module.add_function(wrap_pyfunction!(ascii, module)?)?; - module.add_function(wrap_pyfunction!(asin, module)?)?; - module.add_function(wrap_pyfunction!(atan, module)?)?; - module.add_function(wrap_pyfunction!(avg, module)?)?; - module.add_function(wrap_pyfunction!(bit_length, module)?)?; - module.add_function(wrap_pyfunction!(btrim, module)?)?; - module.add_function(wrap_pyfunction!(ceil, module)?)?; - module.add_function(wrap_pyfunction!(character_length, module)?)?; - module.add_function(wrap_pyfunction!(chr, module)?)?; - module.add_function(wrap_pyfunction!(col, module)?)?; - module.add_function(wrap_pyfunction!(concat_ws, module)?)?; - module.add_function(wrap_pyfunction!(concat, module)?)?; - module.add_function(wrap_pyfunction!(cos, module)?)?; - module.add_function(wrap_pyfunction!(count, module)?)?; - module.add_function(wrap_pyfunction!(exp, module)?)?; - module.add_function(wrap_pyfunction!(floor, module)?)?; - module.add_function(wrap_pyfunction!(in_list, module)?)?; - module.add_function(wrap_pyfunction!(initcap, module)?)?; - module.add_function(wrap_pyfunction!(left, module)?)?; - module.add_function(wrap_pyfunction!(lit, module)?)?; - module.add_function(wrap_pyfunction!(ln, module)?)?; - module.add_function(wrap_pyfunction!(log10, module)?)?; - module.add_function(wrap_pyfunction!(log2, module)?)?; - module.add_function(wrap_pyfunction!(lower, module)?)?; - module.add_function(wrap_pyfunction!(lpad, module)?)?; - module.add_function(wrap_pyfunction!(ltrim, module)?)?; - module.add_function(wrap_pyfunction!(max, module)?)?; - module.add_function(wrap_pyfunction!(md5, module)?)?; - module.add_function(wrap_pyfunction!(digest, module)?)?; - module.add_function(wrap_pyfunction!(min, module)?)?; - module.add_function(wrap_pyfunction!(now, module)?)?; - module.add_function(wrap_pyfunction!(octet_length, module)?)?; - module.add_function(wrap_pyfunction!(random, module)?)?; - module.add_function(wrap_pyfunction!(regexp_replace, module)?)?; - module.add_function(wrap_pyfunction!(repeat, module)?)?; - module.add_function(wrap_pyfunction!(replace, module)?)?; - module.add_function(wrap_pyfunction!(reverse, module)?)?; - module.add_function(wrap_pyfunction!(right, module)?)?; - module.add_function(wrap_pyfunction!(round, module)?)?; - module.add_function(wrap_pyfunction!(rpad, module)?)?; - module.add_function(wrap_pyfunction!(rtrim, module)?)?; - module.add_function(wrap_pyfunction!(sha224, module)?)?; - module.add_function(wrap_pyfunction!(sha256, module)?)?; - module.add_function(wrap_pyfunction!(sha384, module)?)?; - module.add_function(wrap_pyfunction!(sha512, module)?)?; - module.add_function(wrap_pyfunction!(signum, module)?)?; - module.add_function(wrap_pyfunction!(sin, module)?)?; - module.add_function(wrap_pyfunction!(split_part, module)?)?; - module.add_function(wrap_pyfunction!(sqrt, 
module)?)?; - module.add_function(wrap_pyfunction!(starts_with, module)?)?; - module.add_function(wrap_pyfunction!(strpos, module)?)?; - module.add_function(wrap_pyfunction!(substr, module)?)?; - module.add_function(wrap_pyfunction!(sum, module)?)?; - module.add_function(wrap_pyfunction!(tan, module)?)?; - module.add_function(wrap_pyfunction!(to_hex, module)?)?; - module.add_function(wrap_pyfunction!(translate, module)?)?; - module.add_function(wrap_pyfunction!(trim, module)?)?; - module.add_function(wrap_pyfunction!(trunc, module)?)?; - module.add_function(wrap_pyfunction!(udaf, module)?)?; - module.add_function(wrap_pyfunction!(udf, module)?)?; - module.add_function(wrap_pyfunction!(upper, module)?)?; +aggregate_function!(avg, Avg); +aggregate_function!(count, Count); +aggregate_function!(max, Max); +aggregate_function!(min, Min); +aggregate_function!(sum, Sum); +aggregate_function!(approx_distinct, ApproxDistinct); +pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { + m.add_wrapped(wrap_pyfunction!(abs))?; + m.add_wrapped(wrap_pyfunction!(acos))?; + m.add_wrapped(wrap_pyfunction!(approx_distinct))?; + m.add_wrapped(wrap_pyfunction!(array))?; + m.add_wrapped(wrap_pyfunction!(ascii))?; + m.add_wrapped(wrap_pyfunction!(asin))?; + m.add_wrapped(wrap_pyfunction!(atan))?; + m.add_wrapped(wrap_pyfunction!(avg))?; + m.add_wrapped(wrap_pyfunction!(bit_length))?; + m.add_wrapped(wrap_pyfunction!(btrim))?; + m.add_wrapped(wrap_pyfunction!(ceil))?; + m.add_wrapped(wrap_pyfunction!(character_length))?; + m.add_wrapped(wrap_pyfunction!(chr))?; + m.add_wrapped(wrap_pyfunction!(concat_ws))?; + m.add_wrapped(wrap_pyfunction!(concat))?; + m.add_wrapped(wrap_pyfunction!(cos))?; + m.add_wrapped(wrap_pyfunction!(count))?; + m.add_wrapped(wrap_pyfunction!(digest))?; + m.add_wrapped(wrap_pyfunction!(exp))?; + m.add_wrapped(wrap_pyfunction!(floor))?; + m.add_wrapped(wrap_pyfunction!(in_list))?; + m.add_wrapped(wrap_pyfunction!(initcap))?; + m.add_wrapped(wrap_pyfunction!(left))?; + m.add_wrapped(wrap_pyfunction!(ln))?; + m.add_wrapped(wrap_pyfunction!(log10))?; + m.add_wrapped(wrap_pyfunction!(log2))?; + m.add_wrapped(wrap_pyfunction!(lower))?; + m.add_wrapped(wrap_pyfunction!(lpad))?; + m.add_wrapped(wrap_pyfunction!(ltrim))?; + m.add_wrapped(wrap_pyfunction!(max))?; + m.add_wrapped(wrap_pyfunction!(md5))?; + m.add_wrapped(wrap_pyfunction!(min))?; + m.add_wrapped(wrap_pyfunction!(now))?; + m.add_wrapped(wrap_pyfunction!(octet_length))?; + m.add_wrapped(wrap_pyfunction!(random))?; + m.add_wrapped(wrap_pyfunction!(regexp_match))?; + m.add_wrapped(wrap_pyfunction!(regexp_replace))?; + m.add_wrapped(wrap_pyfunction!(repeat))?; + m.add_wrapped(wrap_pyfunction!(replace))?; + m.add_wrapped(wrap_pyfunction!(reverse))?; + m.add_wrapped(wrap_pyfunction!(right))?; + m.add_wrapped(wrap_pyfunction!(round))?; + m.add_wrapped(wrap_pyfunction!(rpad))?; + m.add_wrapped(wrap_pyfunction!(rtrim))?; + m.add_wrapped(wrap_pyfunction!(sha224))?; + m.add_wrapped(wrap_pyfunction!(sha256))?; + m.add_wrapped(wrap_pyfunction!(sha384))?; + m.add_wrapped(wrap_pyfunction!(sha512))?; + m.add_wrapped(wrap_pyfunction!(signum))?; + m.add_wrapped(wrap_pyfunction!(sin))?; + m.add_wrapped(wrap_pyfunction!(split_part))?; + m.add_wrapped(wrap_pyfunction!(sqrt))?; + m.add_wrapped(wrap_pyfunction!(starts_with))?; + m.add_wrapped(wrap_pyfunction!(strpos))?; + m.add_wrapped(wrap_pyfunction!(substr))?; + m.add_wrapped(wrap_pyfunction!(sum))?; + m.add_wrapped(wrap_pyfunction!(tan))?; + m.add_wrapped(wrap_pyfunction!(to_hex))?; + 
m.add_wrapped(wrap_pyfunction!(to_timestamp))?; + m.add_wrapped(wrap_pyfunction!(translate))?; + m.add_wrapped(wrap_pyfunction!(trim))?; + m.add_wrapped(wrap_pyfunction!(trunc))?; + m.add_wrapped(wrap_pyfunction!(upper))?; Ok(()) } diff --git a/python/src/lib.rs b/python/src/lib.rs index 4436781bec36..d40bae251c86 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -17,42 +17,36 @@ use pyo3::prelude::*; +mod catalog; mod context; mod dataframe; mod errors; mod expression; mod functions; -mod scalar; -mod to_py; -mod to_rust; -mod types; mod udaf; mod udf; +mod utils; -// taken from https://github.com/PyO3/pyo3/issues/471 -fn register_module_package(py: Python, package_name: &str, module: &PyModule) { - py.import("sys") - .expect("failed to import python sys module") - .dict() - .get_item("modules") - .expect("failed to get python modules dictionary") - .downcast::() - .expect("failed to turn sys.modules into a PyDict") - .set_item(package_name, module) - .expect("failed to inject module"); -} - -/// DataFusion. +/// Low-level DataFusion internal package. +/// +/// The higher-level public API is defined in pure python files under the +/// datafusion directory. #[pymodule] -fn datafusion(py: Python, m: &PyModule) -> PyResult<()> { - m.add_class::()?; - m.add_class::()?; - m.add_class::()?; +fn _internal(py: Python, m: &PyModule) -> PyResult<()> { + // Register the python classes + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; - let functions = PyModule::new(py, "functions")?; - functions::init(functions)?; - register_module_package(py, "datafusion.functions", functions); - m.add_submodule(functions)?; + // Register the functions as a submodule + let funcs = PyModule::new(py, "functions")?; + functions::init_module(funcs)?; + m.add_submodule(funcs)?; Ok(()) } diff --git a/python/src/scalar.rs b/python/src/scalar.rs deleted file mode 100644 index 0c562a940361..000000000000 --- a/python/src/scalar.rs +++ /dev/null @@ -1,36 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
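The lib.rs change above renames the extension module to `_internal` and states that the public API now lives in pure python. That shim is not part of this diff; a hypothetical sketch of what it implies (file path and wrapper names assumed, class names taken from the `#[pyclass(name = ...)]` attributes):

```python
# hypothetical python/datafusion/__init__.py gluing the renamed module together
from ._internal import (
    AggregateUDF,
    DataFrame,
    ExecutionContext,
    Expression,
    ScalarUDF,
    functions,
)


def col(name):
    """Create a column expression (wraps Expression.column)."""
    return Expression.column(name)


def lit(value):
    """Create a literal expression (wraps Expression.literal)."""
    return Expression.literal(value)
```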
- -use pyo3::prelude::*; - -use datafusion::scalar::ScalarValue as _Scalar; - -use crate::to_rust::to_rust_scalar; - -/// An expression that can be used on a DataFrame -#[derive(Debug, Clone)] -pub(crate) struct Scalar { - pub(crate) scalar: _Scalar, -} - -impl<'source> FromPyObject<'source> for Scalar { - fn extract(ob: &'source PyAny) -> PyResult { - Ok(Self { - scalar: to_rust_scalar(ob)?, - }) - } -} diff --git a/python/src/to_py.rs b/python/src/to_py.rs deleted file mode 100644 index 6bc0581c8c70..000000000000 --- a/python/src/to_py.rs +++ /dev/null @@ -1,75 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use datafusion::arrow::array::ArrayRef; -use datafusion::arrow::record_batch::RecordBatch; -use libc::uintptr_t; -use pyo3::prelude::*; -use pyo3::types::PyList; -use pyo3::PyErr; -use std::convert::From; - -use crate::errors; - -pub fn to_py_array(array: &ArrayRef, py: Python) -> PyResult { - let (array_pointer, schema_pointer) = - array.to_raw().map_err(errors::DataFusionError::from)?; - - let pa = py.import("pyarrow")?; - - let array = pa.getattr("Array")?.call_method1( - "_import_from_c", - (array_pointer as uintptr_t, schema_pointer as uintptr_t), - )?; - Ok(array.to_object(py)) -} - -fn to_py_batch<'a>( - batch: &RecordBatch, - py: Python, - pyarrow: &'a PyModule, -) -> Result { - let mut py_arrays = vec![]; - let mut py_names = vec![]; - - let schema = batch.schema(); - for (array, field) in batch.columns().iter().zip(schema.fields().iter()) { - let array = to_py_array(array, py)?; - - py_arrays.push(array); - py_names.push(field.name()); - } - - let record = pyarrow - .getattr("RecordBatch")? - .call_method1("from_arrays", (py_arrays, py_names))?; - - Ok(PyObject::from(record)) -} - -/// Converts a &[RecordBatch] into a Vec represented in PyArrow -pub fn to_py(batches: &[RecordBatch]) -> PyResult { - Python::with_gil(|py| { - let pyarrow = PyModule::import(py, "pyarrow")?; - let mut py_batches = vec![]; - for batch in batches { - py_batches.push(to_py_batch(batch, py, pyarrow)?); - } - let list = PyList::new(py, py_batches); - Ok(PyObject::from(list)) - }) -} diff --git a/python/src/to_rust.rs b/python/src/to_rust.rs deleted file mode 100644 index 7977fe4ff8ce..000000000000 --- a/python/src/to_rust.rs +++ /dev/null @@ -1,122 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::convert::TryFrom; -use std::sync::Arc; - -use datafusion::arrow::{ - array::{make_array_from_raw, ArrayRef}, - datatypes::Field, - datatypes::Schema, - ffi, - record_batch::RecordBatch, -}; -use datafusion::scalar::ScalarValue; -use libc::uintptr_t; -use pyo3::prelude::*; - -use crate::{errors, types::PyDataType}; - -/// converts a pyarrow Array into a Rust Array -pub fn to_rust(ob: &PyAny) -> PyResult { - // prepare a pointer to receive the Array struct - let (array_pointer, schema_pointer) = - ffi::ArrowArray::into_raw(unsafe { ffi::ArrowArray::empty() }); - - // make the conversion through PyArrow's private API - // this changes the pointer's memory and is thus unsafe. In particular, `_export_to_c` can go out of bounds - ob.call_method1( - "_export_to_c", - (array_pointer as uintptr_t, schema_pointer as uintptr_t), - )?; - - let array = unsafe { make_array_from_raw(array_pointer, schema_pointer) } - .map_err(errors::DataFusionError::from)?; - Ok(array) -} - -/// converts a pyarrow batch into a RecordBatch -pub fn to_rust_batch(batch: &PyAny) -> PyResult { - let schema = batch.getattr("schema")?; - let names = schema.getattr("names")?.extract::>()?; - - let fields = names - .iter() - .enumerate() - .map(|(i, name)| { - let field = schema.call_method1("field", (i,))?; - let nullable = field.getattr("nullable")?.extract::()?; - let py_data_type = field.getattr("type")?; - let data_type = py_data_type.extract::()?.data_type; - Ok(Field::new(name, data_type, nullable)) - }) - .collect::>()?; - - let schema = Arc::new(Schema::new(fields)); - - let arrays = (0..names.len()) - .map(|i| { - let array = batch.call_method1("column", (i,))?; - to_rust(array) - }) - .collect::>()?; - - let batch = - RecordBatch::try_new(schema, arrays).map_err(errors::DataFusionError::from)?; - Ok(batch) -} - -/// converts a pyarrow Scalar into a Rust Scalar -pub fn to_rust_scalar(ob: &PyAny) -> PyResult { - let t = ob - .getattr("__class__")? - .getattr("__name__")? 
- .extract::<&str>()?; - - let p = ob.call_method0("as_py")?; - - Ok(match t { - "Int8Scalar" => ScalarValue::Int8(Some(p.extract::()?)), - "Int16Scalar" => ScalarValue::Int16(Some(p.extract::()?)), - "Int32Scalar" => ScalarValue::Int32(Some(p.extract::()?)), - "Int64Scalar" => ScalarValue::Int64(Some(p.extract::()?)), - "UInt8Scalar" => ScalarValue::UInt8(Some(p.extract::()?)), - "UInt16Scalar" => ScalarValue::UInt16(Some(p.extract::()?)), - "UInt32Scalar" => ScalarValue::UInt32(Some(p.extract::()?)), - "UInt64Scalar" => ScalarValue::UInt64(Some(p.extract::()?)), - "FloatScalar" => ScalarValue::Float32(Some(p.extract::()?)), - "DoubleScalar" => ScalarValue::Float64(Some(p.extract::()?)), - "BooleanScalar" => ScalarValue::Boolean(Some(p.extract::()?)), - "StringScalar" => ScalarValue::Utf8(Some(p.extract::()?)), - "LargeStringScalar" => ScalarValue::LargeUtf8(Some(p.extract::()?)), - other => { - return Err(errors::DataFusionError::Common(format!( - "Type \"{}\"not yet implemented", - other - )) - .into()) - } - }) -} - -pub fn to_rust_schema(ob: &PyAny) -> PyResult { - let c_schema = ffi::FFI_ArrowSchema::empty(); - let c_schema_ptr = &c_schema as *const ffi::FFI_ArrowSchema; - ob.call_method1("_export_to_c", (c_schema_ptr as uintptr_t,))?; - let schema = Schema::try_from(&c_schema).map_err(errors::DataFusionError::from)?; - Ok(schema) -} diff --git a/python/src/types.rs b/python/src/types.rs deleted file mode 100644 index bd6ef0d376e6..000000000000 --- a/python/src/types.rs +++ /dev/null @@ -1,65 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
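Both conversion layers deleted here, the hand-rolled FFI helpers in to_rust.rs above and the pyarrow type-id table in types.rs below, are superseded by the arrow crate's `pyarrow` feature: schemas, arrays, record batches, and scalars now cross the boundary via `PyArrowConvert`. The round trip from Python, using only APIs present in this diff (the `datafusion` import assumes the pure-python re-export):

```python
import pyarrow as pa
from datafusion import ExecutionContext

ctx = ExecutionContext()
batch = pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], names=["a"])
ctx.register_record_batches("t", [[batch]])              # pyarrow -> Rust
batches = ctx.sql("SELECT a + 1 AS a FROM t").collect()  # Rust -> pyarrow
assert batches[0].column(0) == pa.array([2, 3, 4])
```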
- -use datafusion::arrow::datatypes::DataType; -use pyo3::{FromPyObject, PyAny, PyResult}; - -use crate::errors; - -/// utility struct to convert PyObj to native DataType -#[derive(Debug, Clone)] -pub struct PyDataType { - pub data_type: DataType, -} - -impl<'source> FromPyObject<'source> for PyDataType { - fn extract(ob: &'source PyAny) -> PyResult { - let id = ob.getattr("id")?.extract::()?; - let data_type = data_type_id(&id)?; - Ok(PyDataType { data_type }) - } -} - -fn data_type_id(id: &i32) -> Result { - // see https://github.com/apache/arrow/blob/3694794bdfd0677b95b8c95681e392512f1c9237/python/pyarrow/includes/libarrow.pxd - // this is not ideal as it does not generalize for non-basic types - // Find a way to get a unique name from the pyarrow.DataType - Ok(match id { - 1 => DataType::Boolean, - 2 => DataType::UInt8, - 3 => DataType::Int8, - 4 => DataType::UInt16, - 5 => DataType::Int16, - 6 => DataType::UInt32, - 7 => DataType::Int32, - 8 => DataType::UInt64, - 9 => DataType::Int64, - 10 => DataType::Float16, - 11 => DataType::Float32, - 12 => DataType::Float64, - 13 => DataType::Utf8, - 14 => DataType::Binary, - 34 => DataType::LargeUtf8, - 35 => DataType::LargeBinary, - other => { - return Err(errors::DataFusionError::Common(format!( - "The type {} is not valid", - other - ))) - } - }) -} diff --git a/python/src/udaf.rs b/python/src/udaf.rs index 83e8be05db60..1de6e63205ed 100644 --- a/python/src/udaf.rs +++ b/python/src/udaf.rs @@ -20,41 +20,33 @@ use std::sync::Arc; use pyo3::{prelude::*, types::PyTuple}; use datafusion::arrow::array::ArrayRef; - -use datafusion::error::Result; -use datafusion::{ - error::DataFusionError as InnerDataFusionError, physical_plan::Accumulator, - scalar::ScalarValue, -}; - -use crate::scalar::Scalar; -use crate::to_py::to_py_array; -use crate::to_rust::to_rust_scalar; +use datafusion::arrow::datatypes::DataType; +use datafusion::arrow::pyarrow::PyArrowConvert; +use datafusion::error::{DataFusionError, Result}; +use datafusion::logical_plan; +use datafusion::physical_plan::aggregates::AccumulatorFunctionImplementation; +use datafusion::physical_plan::udaf::AggregateUDF; +use datafusion::physical_plan::Accumulator; +use datafusion::scalar::ScalarValue; + +use crate::expression::PyExpr; +use crate::utils::parse_volatility; #[derive(Debug)] -struct PyAccumulator { +struct RustAccumulator { accum: PyObject, } -impl PyAccumulator { +impl RustAccumulator { fn new(accum: PyObject) -> Self { Self { accum } } } -impl Accumulator for PyAccumulator { - fn state(&self) -> Result> { - Python::with_gil(|py| { - let state = self - .accum - .as_ref(py) - .call_method0("to_scalars") - .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))? 
- .extract::>() - .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))?; - - Ok(state.into_iter().map(|v| v.scalar).collect::>()) - }) +impl Accumulator for RustAccumulator { + fn state(&self) -> Result> { + Python::with_gil(|py| self.accum.as_ref(py).call_method0("state")?.extract()) + .map_err(|e| DataFusionError::Execution(format!("{}", e))) } fn update(&mut self, _values: &[ScalarValue]) -> Result<()> { @@ -67,39 +59,25 @@ impl Accumulator for PyAccumulator { todo!() } - fn evaluate(&self) -> Result { - Python::with_gil(|py| { - let value = self - .accum - .as_ref(py) - .call_method0("evaluate") - .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))?; - - to_rust_scalar(value) - .map_err(|e| InnerDataFusionError::Execution(format!("{}", e))) - }) + fn evaluate(&self) -> Result { + Python::with_gil(|py| self.accum.as_ref(py).call_method0("evaluate")?.extract()) + .map_err(|e| DataFusionError::Execution(format!("{}", e))) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { Python::with_gil(|py| { // 1. cast args to Pyarrow array - // 2. call function - - // 1. let py_args = values .iter() - .map(|arg| { - // remove unwrap - to_py_array(arg, py).unwrap() - }) + .map(|arg| arg.data().to_owned().to_pyarrow(py).unwrap()) .collect::>(); let py_args = PyTuple::new(py, py_args); - // update accumulator + // 2. call function self.accum .as_ref(py) .call_method1("update", py_args) - .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))?; + .map_err(|e| DataFusionError::Execution(format!("{}", e)))?; Ok(()) }) @@ -107,33 +85,69 @@ impl Accumulator for PyAccumulator { fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { Python::with_gil(|py| { - // 1. cast states to Pyarrow array - // 2. merge let state = &states[0]; - let state = to_py_array(state, py) - .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))?; + // 1. cast states to Pyarrow array + let state = state + .to_pyarrow(py) + .map_err(|e| DataFusionError::Execution(format!("{}", e)))?; - // 2. + // 2. 
call merge
             self.accum
                 .as_ref(py)
                 .call_method1("merge", (state,))
-                .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))?;
+                .map_err(|e| DataFusionError::Execution(format!("{}", e)))?;
 
             Ok(())
         })
     }
 }
 
-pub fn array_udaf(
-    accumulator: PyObject,
-) -> Arc<dyn Fn() -> Result<Box<dyn Accumulator>> + Send + Sync> {
+pub fn to_rust_accumulator(accum: PyObject) -> AccumulatorFunctionImplementation {
     Arc::new(move || -> Result<Box<dyn Accumulator>> {
-        let accumulator = Python::with_gil(|py| {
-            accumulator
+        let accum = Python::with_gil(|py| {
+            accum
                 .call0(py)
-                .map_err(|e| InnerDataFusionError::Execution(format!("{}", e)))
+                .map_err(|e| DataFusionError::Execution(format!("{}", e)))
         })?;
 
-        Ok(Box::new(PyAccumulator::new(accumulator)))
+        Ok(Box::new(RustAccumulator::new(accum)))
     })
 }
+
+/// Represents an AggregateUDF
+#[pyclass(name = "AggregateUDF", module = "datafusion", subclass)]
+#[derive(Debug, Clone)]
+pub struct PyAggregateUDF {
+    pub(crate) function: AggregateUDF,
+}
+
+#[pymethods]
+impl PyAggregateUDF {
+    #[new(name, accumulator, input_type, return_type, state_type, volatility)]
+    fn new(
+        name: &str,
+        accumulator: PyObject,
+        input_type: DataType,
+        return_type: DataType,
+        state_type: Vec<DataType>,
+        volatility: &str,
+    ) -> PyResult<Self> {
+        let function = logical_plan::create_udaf(
+            &name,
+            input_type,
+            Arc::new(return_type),
+            parse_volatility(volatility)?,
+            to_rust_accumulator(accumulator),
+            Arc::new(state_type),
+        );
+        Ok(Self { function })
+    }
+
+    /// creates a new PyExpr with a call to the udf
+    #[call]
+    #[args(args = "*")]
+    fn __call__(&self, args: Vec<PyExpr>) -> PyResult<PyExpr> {
+        let args = args.iter().map(|e| e.expr.clone()).collect();
+        Ok(self.function.call(args).into())
+    }
+}
diff --git a/python/src/udf.rs b/python/src/udf.rs
index 49a18d993241..379c449870b2 100644
--- a/python/src/udf.rs
+++ b/python/src/udf.rs
@@ -15,46 +15,84 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use pyo3::{prelude::*, types::PyTuple};
+use std::sync::Arc;
 
-use datafusion::{arrow::array, physical_plan::functions::make_scalar_function};
+use pyo3::{prelude::*, types::PyTuple};
 
+use datafusion::arrow::array::ArrayRef;
+use datafusion::arrow::datatypes::DataType;
+use datafusion::arrow::pyarrow::PyArrowConvert;
 use datafusion::error::DataFusionError;
-use datafusion::physical_plan::functions::ScalarFunctionImplementation;
+use datafusion::logical_plan;
+use datafusion::physical_plan::functions::{
+    make_scalar_function, ScalarFunctionImplementation,
+};
+use datafusion::physical_plan::udf::ScalarUDF;
 
-use crate::to_py::to_py_array;
-use crate::to_rust::to_rust;
+use crate::expression::PyExpr;
+use crate::utils::parse_volatility;
 
-/// creates a DataFusion's UDF implementation from a python function that expects pyarrow arrays
-/// This is more efficient as it performs a zero-copy of the contents.
-pub fn array_udf(func: PyObject) -> ScalarFunctionImplementation {
+/// Create a DataFusion's UDF implementation from a python function
+/// that expects pyarrow arrays. This is more efficient as it performs
+/// a zero-copy of the contents.
+fn to_rust_function(func: PyObject) -> ScalarFunctionImplementation {
     make_scalar_function(
-        move |args: &[array::ArrayRef]| -> Result<array::ArrayRef, DataFusionError> {
+        move |args: &[ArrayRef]| -> Result<ArrayRef, DataFusionError> {
             Python::with_gil(|py| {
                 // 1. cast args to Pyarrow arrays
-                // 2. call function
-                // 3. cast to arrow::array::Array
-
-                // 1.
                let py_args = args
                     .iter()
-                    .map(|arg| {
-                        // remove unwrap
-                        to_py_array(arg, py).unwrap()
-                    })
+                    .map(|arg| arg.data().to_owned().to_pyarrow(py).unwrap())
                     .collect::<Vec<PyObject>>();
                 let py_args = PyTuple::new(py, py_args);
 
-                // 2.
+                // 2. call function
                 let value = func.as_ref(py).call(py_args, None);
                 let value = match value {
                     Ok(n) => Ok(n),
                     Err(error) => Err(DataFusionError::Execution(format!("{:?}", error))),
                 }?;
 
-                let array = to_rust(value).unwrap();
+                // 3. cast to arrow::array::Array
+                let array = ArrayRef::from_pyarrow(value).unwrap();
                 Ok(array)
             })
         },
     )
 }
+
+/// Represents a PyScalarUDF
+#[pyclass(name = "ScalarUDF", module = "datafusion", subclass)]
+#[derive(Debug, Clone)]
+pub struct PyScalarUDF {
+    pub(crate) function: ScalarUDF,
+}
+
+#[pymethods]
+impl PyScalarUDF {
+    #[new(name, func, input_types, return_type, volatility)]
+    fn new(
+        name: &str,
+        func: PyObject,
+        input_types: Vec<DataType>,
+        return_type: DataType,
+        volatility: &str,
+    ) -> PyResult<Self> {
+        let function = logical_plan::create_udf(
+            name,
+            input_types,
+            Arc::new(return_type),
+            parse_volatility(volatility)?,
+            to_rust_function(func),
+        );
+        Ok(Self { function })
+    }
+
+    /// creates a new PyExpr with a call to the udf
+    #[call]
+    #[args(args = "*")]
+    fn __call__(&self, args: Vec<PyExpr>) -> PyResult<PyExpr> {
+        let args = args.iter().map(|e| e.expr.clone()).collect();
+        Ok(self.function.call(args).into())
+    }
+}
diff --git a/python/src/utils.rs b/python/src/utils.rs
new file mode 100644
index 000000000000..c8e1c63b1d0f
--- /dev/null
+++ b/python/src/utils.rs
@@ -0,0 +1,50 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::future::Future;
+
+use pyo3::prelude::*;
+use tokio::runtime::Runtime;
+
+use datafusion::physical_plan::functions::Volatility;
+
+use crate::errors::DataFusionError;
+
+/// Utility to collect rust futures with GIL released
+pub(crate) fn wait_for_future<F: Future>(py: Python, f: F) -> F::Output
+where
+    F: Send,
+    F::Output: Send,
+{
+    let rt = Runtime::new().unwrap();
+    py.allow_threads(|| rt.block_on(f))
+}
+
+pub(crate) fn parse_volatility(value: &str) -> Result<Volatility, DataFusionError> {
+    Ok(match value {
+        "immutable" => Volatility::Immutable,
+        "stable" => Volatility::Stable,
+        "volatile" => Volatility::Volatile,
+        value => {
+            return Err(DataFusionError::Common(format!(
+                "Unsupported volatility type: `{}`, supported \
+                 values are: immutable, stable and volatile.",
+                value
+            )))
+        }
+    })
+}
diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py
deleted file mode 100644
index 67cf502c445e..000000000000
--- a/python/tests/test_functions.py
+++ /dev/null
@@ -1,63 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.
diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py
deleted file mode 100644
index 67cf502c445e..000000000000
--- a/python/tests/test_functions.py
+++ /dev/null
@@ -1,63 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import pyarrow as pa
-import pytest
-from datafusion import ExecutionContext
-from datafusion import functions as f
-
-
-@pytest.fixture
-def df():
-    ctx = ExecutionContext()
-    # create a RecordBatch and a new DataFrame from it
-    batch = pa.RecordBatch.from_arrays(
-        [pa.array(["Hello", "World", "!"]), pa.array([4, 5, 6])],
-        names=["a", "b"],
-    )
-    return ctx.create_dataframe([[batch]])
-
-
-def test_lit(df):
-    """test lit function"""
-    df = df.select(
-        f.lit(1),
-        f.lit("1"),
-        f.lit("OK"),
-        f.lit(3.14),
-        f.lit(True),
-        f.lit(b"hello world"),
-    )
-    result = df.collect()
-    assert len(result) == 1
-    result = result[0]
-    assert result.column(0) == pa.array([1] * 3)
-    assert result.column(1) == pa.array(["1"] * 3)
-    assert result.column(2) == pa.array(["OK"] * 3)
-    assert result.column(3) == pa.array([3.14] * 3)
-    assert result.column(4) == pa.array([True] * 3)
-    assert result.column(5) == pa.array([b"hello world"] * 3)
-
-
-def test_lit_arith(df):
-    """test lit function within arithmatics"""
-    df = df.select(f.lit(1) + f.col("b"), f.concat(f.col("a"), f.lit("!")))
-    result = df.collect()
-    assert len(result) == 1
-    result = result[0]
-    assert result.column(0) == pa.array([5, 6, 7])
-    assert result.column(1) == pa.array(["Hello!", "World!", "!!"])
diff --git a/python/tests/test_math_functions.py b/python/tests/test_math_functions.py
deleted file mode 100644
index 98656b8c4f42..000000000000
--- a/python/tests/test_math_functions.py
+++ /dev/null
@@ -1,70 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import numpy as np
-import pyarrow as pa
-import pytest
-from datafusion import ExecutionContext
-from datafusion import functions as f
-
-
-@pytest.fixture
-def df():
-    ctx = ExecutionContext()
-    # create a RecordBatch and a new DataFrame from it
-    batch = pa.RecordBatch.from_arrays(
-        [pa.array([0.1, -0.7, 0.55])], names=["value"]
-    )
-    return ctx.create_dataframe([[batch]])
-
-
-def test_math_functions(df):
-    values = np.array([0.1, -0.7, 0.55])
-    col_v = f.col("value")
-    df = df.select(
-        f.abs(col_v),
-        f.sin(col_v),
-        f.cos(col_v),
-        f.tan(col_v),
-        f.asin(col_v),
-        f.acos(col_v),
-        f.exp(col_v),
-        f.ln(col_v + f.lit(1)),
-        f.log2(col_v + f.lit(1)),
-        f.log10(col_v + f.lit(1)),
-        f.random(),
-    )
-    result = df.collect()
-    assert len(result) == 1
-    result = result[0]
-    np.testing.assert_array_almost_equal(result.column(0), np.abs(values))
-    np.testing.assert_array_almost_equal(result.column(1), np.sin(values))
-    np.testing.assert_array_almost_equal(result.column(2), np.cos(values))
-    np.testing.assert_array_almost_equal(result.column(3), np.tan(values))
-    np.testing.assert_array_almost_equal(result.column(4), np.arcsin(values))
-    np.testing.assert_array_almost_equal(result.column(5), np.arccos(values))
-    np.testing.assert_array_almost_equal(result.column(6), np.exp(values))
-    np.testing.assert_array_almost_equal(
-        result.column(7), np.log(values + 1.0)
-    )
-    np.testing.assert_array_almost_equal(
-        result.column(8), np.log2(values + 1.0)
-    )
-    np.testing.assert_array_almost_equal(
-        result.column(9), np.log10(values + 1.0)
-    )
-    np.testing.assert_array_less(result.column(10), np.ones_like(values))
diff --git a/python/tests/test_pa_types.py b/python/tests/test_pa_types.py
deleted file mode 100644
index 04f6110e3a42..000000000000
--- a/python/tests/test_pa_types.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import pyarrow as pa
-
-
-def test_type_ids():
-    # Having this fixed is very important because internally we rely on this id
-    # to parse from python
-    for idx, arrow_type in [
-        (0, pa.null()),
-        (1, pa.bool_()),
-        (2, pa.uint8()),
-        (3, pa.int8()),
-        (4, pa.uint16()),
-        (5, pa.int16()),
-        (6, pa.uint32()),
-        (7, pa.int32()),
-        (8, pa.uint64()),
-        (9, pa.int64()),
-        (10, pa.float16()),
-        (11, pa.float32()),
-        (12, pa.float64()),
-        (13, pa.string()),
-        (13, pa.utf8()),
-        (14, pa.binary()),
-        (16, pa.date32()),
-        (17, pa.date64()),
-        (18, pa.timestamp("us")),
-        (19, pa.time32("s")),
-        (20, pa.time64("us")),
-        (23, pa.decimal128(8, 1)),
-        (34, pa.large_utf8()),
-        (35, pa.large_binary()),
-    ]:
-        assert idx == arrow_type.id
diff --git a/python/tests/test_string_functions.py b/python/tests/test_string_functions.py
deleted file mode 100644
index 965f08707285..000000000000
--- a/python/tests/test_string_functions.py
+++ /dev/null
@@ -1,121 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import pyarrow as pa
-import pytest
-from datafusion import ExecutionContext
-from datafusion import functions as f
-
-
-@pytest.fixture
-def df():
-    ctx = ExecutionContext()
-
-    # create a RecordBatch and a new DataFrame from it
-    batch = pa.RecordBatch.from_arrays(
-        [pa.array(["Hello", "World", "!"]), pa.array([4, 5, 6])],
-        names=["a", "b"],
-    )
-
-    return ctx.create_dataframe([[batch]])
-
-
-def test_string_functions(df):
-    df = df.select(f.md5(f.col("a")), f.lower(f.col("a")))
-    result = df.collect()
-    assert len(result) == 1
-    result = result[0]
-    assert result.column(0) == pa.array(
-        [
-            "8b1a9953c4611296a827abf8c47804d7",
-            "f5a7924e621e84c9280a9a27e1bcb7f6",
-            "9033e0e305f247c0c3c80d0c7848c8b3",
-        ]
-    )
-    assert result.column(1) == pa.array(["hello", "world", "!"])
-
-
-def test_hash_functions(df):
-    df = df.select(
-        *[
-            f.digest(f.col("a"), f.lit(m))
-            for m in ("md5", "sha256", "sha512", "blake2s", "blake3")
-        ]
-    )
-    result = df.collect()
-    assert len(result) == 1
-    result = result[0]
-    b = bytearray.fromhex
-    assert result.column(0) == pa.array(
-        [
-            b("8B1A9953C4611296A827ABF8C47804D7"),
-            b("F5A7924E621E84C9280A9A27E1BCB7F6"),
-            b("9033E0E305F247C0C3C80D0C7848C8B3"),
-        ]
-    )
-    assert result.column(1) == pa.array(
-        [
-            b(
-                "185F8DB32271FE25F561A6FC938B2E264306EC304EDA518007D1764826381969"
-            ),
-            b(
-                "78AE647DC5544D227130A0682A51E30BC7777FBB6D8A8F17007463A3ECD1D524"
-            ),
-            b(
-                "BB7208BC9B5D7C04F1236A82A0093A5E33F40423D5BA8D4266F7092C3BA43B62"
-            ),
-        ]
-    )
-    assert result.column(2) == pa.array(
-        [
-            b(
-                "3615F80C9D293ED7402687F94B22D58E529B8CC7916F8FAC7FDDF7FBD5AF4CF777D3D795A7A00A16BF7E7F3FB9561EE9BAAE480DA9FE7A18769E71886B03F315"
-            ),
-            b(
"8EA77393A42AB8FA92500FB077A9509CC32BC95E72712EFA116EDAF2EDFAE34FBB682EFDD6C5DD13C117E08BD4AAEF71291D8AACE2F890273081D0677C16DF0F" - ), - b( - "3831A6A6155E509DEE59A7F451EB35324D8F8F2DF6E3708894740F98FDEE23889F4DE5ADB0C5010DFB555CDA77C8AB5DC902094C52DE3278F35A75EBC25F093A" - ), - ] - ) - assert result.column(3) == pa.array( - [ - b( - "F73A5FBF881F89B814871F46E26AD3FA37CB2921C5E8561618639015B3CCBB71" - ), - b( - "B792A0383FB9E7A189EC150686579532854E44B71AC394831DAED169BA85CCC5" - ), - b( - "27988A0E51812297C77A433F635233346AEE29A829DCF4F46E0F58F402C6CFCB" - ), - ] - ) - assert result.column(4) == pa.array( - [ - b( - "FBC2B0516EE8744D293B980779178A3508850FDCFE965985782C39601B65794F" - ), - b( - "BF73D18575A736E4037D45F9E316085B86C19BE6363DE6AA789E13DEAACC1C4E" - ), - b( - "C8D11B9F7237E4034ADBCD2005735F9BC4C597C75AD89F4492BEC8F77D15F7EB" - ), - ] - )