From 134e3b5efe44cf72e32dddad9efe2019d1fc1c9f Mon Sep 17 00:00:00 2001
From: Mike Seddon
Date: Thu, 8 Apr 2021 12:20:22 +1000
Subject: [PATCH] input_file_name

Add an input_file_name() scalar function that returns the name of the file
a row was read from. The source filename is threaded through the CSV, JSON
and Parquet readers and exposed via the "filename" key of the RecordBatch
schema metadata, and ScalarFunctionImplementation gains a &Schema argument
so that scalar functions can inspect that metadata.
---
 rust/arrow/examples/read_csv.rs                    |   3 +-
 rust/arrow/src/csv/reader.rs                       |  45 +-
 rust/arrow/src/csv/writer.rs                       |   1 +
 rust/arrow/src/json/reader.rs                      |  57 ++-
 .../examples/simple_udf.rs                         |   8 +-
 rust/datafusion/README.md                          |   2 +
 rust/datafusion/src/execution/context.rs           |   4 +-
 .../src/execution/dataframe_impl.rs                |   6 +-
 .../src/physical_plan/array_expressions.rs         |   6 +-
 .../src/physical_plan/crypto_expressions.rs        |  12 +-
 rust/datafusion/src/physical_plan/csv.rs           |   1 +
 .../src/physical_plan/datetime_expressions.rs      |  12 +-
 .../src/physical_plan/expressions/nullif.rs        |   8 +-
 rust/datafusion/src/physical_plan/functions.rs     | 466 +++++++++---------
 .../src/physical_plan/math_expressions.rs          |   4 +-
 rust/datafusion/src/physical_plan/parquet.rs       |  11 +-
 .../src/physical_plan/regex_expressions.rs         |  11 +-
 .../src/physical_plan/string_expressions.rs        |  70 ++-
 .../src/physical_plan/unicode_expressions.rs       |  35 +-
 rust/datafusion/src/sql/planner.rs                 |   2 +-
 rust/datafusion/tests/sql.rs                       |  68 ++-
 21 files changed, 530 insertions(+), 302 deletions(-)

diff --git a/rust/arrow/examples/read_csv.rs b/rust/arrow/examples/read_csv.rs
index 9e2b9c34c86a1..a4ed94538cb2c 100644
--- a/rust/arrow/examples/read_csv.rs
+++ b/rust/arrow/examples/read_csv.rs
@@ -34,7 +34,8 @@ fn main() {
 
     let file = File::open("test/data/uk_cities.csv").unwrap();
 
-    let mut csv = csv::Reader::new(file, Arc::new(schema), false, None, 1024, None, None);
+    let mut csv =
+        csv::Reader::new(None, file, Arc::new(schema), false, None, 1024, None, None);
     let _batch = csv.next().unwrap().unwrap();
     #[cfg(feature = "prettyprint")]
     {
diff --git a/rust/arrow/src/csv/reader.rs b/rust/arrow/src/csv/reader.rs
index 985c88b4978fa..b6532d84cd637 100644
--- a/rust/arrow/src/csv/reader.rs
+++ b/rust/arrow/src/csv/reader.rs
@@ -36,14 +36,14 @@
 //!
 //! let file = File::open("test/data/uk_cities.csv").unwrap();
 //!
-//! let mut csv = csv::Reader::new(file, Arc::new(schema), false, None, 1024, None, None);
+//! let mut csv = csv::Reader::new(None, file, Arc::new(schema), false, None, 1024, None, None);
 //! let batch = csv.next().unwrap().unwrap();
 //! ```
 
 use core::cmp::min;
 use lazy_static::lazy_static;
 use regex::{Regex, RegexBuilder};
-use std::collections::HashSet;
+use std::collections::{HashMap, HashSet};
 use std::fmt;
 use std::fs::File;
 use std::io::{Read, Seek, SeekFrom};
@@ -251,6 +251,8 @@ pub struct Reader<R: Read> {
     projection: Option<Vec<usize>>,
     /// File reader
     reader: csv_crate::Reader<R>,
+    /// Current file
+    filename: Option<String>,
     /// Current line number
     line_number: usize,
     /// Maximum number of rows to read
@@ -280,7 +282,9 @@ impl<R: Read> Reader<R> {
    /// If reading a `File` or an input that supports `std::io::Read` and `std::io::Seek`, and
     /// you want to customise the Reader, such as to enable schema inference, use
     /// `ReaderBuilder`.
+    #[allow(clippy::too_many_arguments)]
     pub fn new(
+        filename: Option<String>,
         reader: R,
         schema: SchemaRef,
         has_header: bool,
@@ -290,7 +294,8 @@ impl<R: Read> Reader<R> {
         projection: Option<Vec<usize>>,
     ) -> Self {
         Self::from_reader(
-            reader, schema, has_header, delimiter, batch_size, bounds, projection,
+            filename, reader, schema, has_header, delimiter, batch_size, bounds,
+            projection,
         )
     }
 
@@ -313,7 +318,9 @@ impl<R: Read> Reader<R> {
     ///
     /// This constructor allows you more flexibility in what records are processed by the
     /// csv reader.
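+    ///
+    /// A minimal sketch of the new `filename` argument (schema and path here are
+    /// illustrative): passing `Some(path)` records the source file under the
+    /// `"filename"` key of each emitted batch's schema metadata.
+    ///
+    /// ```
+    /// use std::fs::File;
+    /// use std::sync::Arc;
+    /// use arrow::csv;
+    /// use arrow::datatypes::{DataType, Field, Schema};
+    ///
+    /// let schema = Schema::new(vec![
+    ///     Field::new("city", DataType::Utf8, false),
+    ///     Field::new("lat", DataType::Float64, false),
+    ///     Field::new("lng", DataType::Float64, false),
+    /// ]);
+    /// let path = "test/data/uk_cities.csv";
+    /// let mut csv = csv::Reader::from_reader(
+    ///     Some(path.to_string()),
+    ///     File::open(path).unwrap(),
+    ///     Arc::new(schema),
+    ///     false, // has_header
+    ///     None,  // delimiter
+    ///     1024,  // batch_size
+    ///     None,  // bounds
+    ///     None,  // projection
+    /// );
+    /// let batch = csv.next().unwrap().unwrap();
+    /// assert_eq!(
+    ///     batch.schema().metadata().get("filename").map(|s| s.as_str()),
+    ///     Some(path)
+    /// );
+    /// ```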
+    #[allow(clippy::too_many_arguments)]
     pub fn from_reader(
+        filename: Option<String>,
         reader: R,
         schema: SchemaRef,
         has_header: bool,
@@ -359,6 +366,7 @@ impl<R: Read> Reader<R> {
             schema,
             projection,
             reader: csv_reader,
+            filename,
             line_number: if has_header { start + 1 } else { start },
             batch_size,
             end,
@@ -406,6 +414,19 @@ impl<R: Read> Iterator for Reader<R> {
 
         self.line_number += read_records;
 
+        let result = result.map(|batch| match self.filename.clone() {
+            Some(filename) => {
+                let mut metadata = HashMap::new();
+                metadata.insert("filename".to_string(), filename);
+                let schema = Arc::new(Schema::new_with_metadata(
+                    batch.schema().fields().clone(),
+                    metadata,
+                ));
+                RecordBatch::try_new(schema, batch.columns().to_vec()).unwrap()
+            }
+            None => batch,
+        });
+
         Some(result)
     }
 }
@@ -670,6 +691,8 @@ fn build_boolean_array(
 /// CSV file reader builder
 #[derive(Debug)]
 pub struct ReaderBuilder {
+    /// Optional filename
+    filename: Option<String>,
     /// Optional schema for the CSV file
     ///
     /// If the schema is not supplied, the reader will try to infer the schema
@@ -699,6 +722,7 @@ pub struct ReaderBuilder {
 impl Default for ReaderBuilder {
     fn default() -> Self {
         Self {
+            filename: None,
             schema: None,
             has_header: false,
             delimiter: None,
@@ -738,6 +762,12 @@ impl ReaderBuilder {
         ReaderBuilder::default()
     }
 
+    /// Set the CSV file's source file name
+    pub fn with_file_name(mut self, file_name: String) -> Self {
+        self.filename = Some(file_name);
+        self
+    }
+
     /// Set the CSV file's schema
     pub fn with_schema(mut self, schema: SchemaRef) -> Self {
         self.schema = Some(schema);
@@ -794,6 +824,7 @@ impl ReaderBuilder {
             }
         };
         Ok(Reader::from_reader(
+            self.filename,
             reader,
             schema,
             self.has_header,
@@ -827,6 +858,7 @@ mod tests {
         let file = File::open("test/data/uk_cities.csv").unwrap();
 
         let mut csv = Reader::new(
+            None,
             file,
             Arc::new(schema.clone()),
             false,
@@ -874,6 +906,7 @@ mod tests {
         let file = File::open("test/data/uk_cities.csv").unwrap();
 
         let mut csv = Reader::new(
+            None,
             file,
             Arc::new(schema.clone()),
             false,
@@ -905,6 +938,7 @@ mod tests {
             .chain(Cursor::new("\n".to_string()))
             .chain(file_without_headers);
         let mut csv = Reader::from_reader(
+            None,
             both_files,
             Arc::new(schema),
             true,
@@ -1002,6 +1036,7 @@ mod tests {
         let file = File::open("test/data/uk_cities.csv").unwrap();
 
         let mut csv = Reader::new(
+            None,
             file,
             Arc::new(schema),
             false,
@@ -1031,7 +1066,8 @@ mod tests {
 
         let file = File::open("test/data/null_test.csv").unwrap();
 
-        let mut csv = Reader::new(file, Arc::new(schema), true, None, 1024, None, None);
+        let mut csv =
+            Reader::new(None, file, Arc::new(schema), true, None, 1024, None, None);
         let batch = csv.next().unwrap().unwrap();
 
         assert_eq!(false, batch.column(1).is_null(0));
@@ -1227,6 +1263,7 @@ mod tests {
         let reader = std::io::Cursor::new(data);
 
         let mut csv = Reader::new(
+            None,
             reader,
             Arc::new(schema),
             false,
diff --git a/rust/arrow/src/csv/writer.rs b/rust/arrow/src/csv/writer.rs
index e9d8565b2a5b0..bf6f2e685442f 100644
--- a/rust/arrow/src/csv/writer.rs
+++ b/rust/arrow/src/csv/writer.rs
@@ -628,6 +628,7 @@ sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555000000,23:46:03\n";
         buf.set_position(0);
 
         let mut reader = Reader::new(
+            None,
             buf,
             Arc::new(schema),
             false,
diff --git a/rust/arrow/src/json/reader.rs b/rust/arrow/src/json/reader.rs
index 31c496c9293bc..9d9ca9a749a67 100644
--- a/rust/arrow/src/json/reader.rs
+++ b/rust/arrow/src/json/reader.rs
@@ -36,12 +36,14 @@
 //!     Field::new("c", DataType::Float64, false),
 //! ]);
 //!
-//! let file = File::open("test/data/basic.json").unwrap();
+//! let filename = "test/data/basic.json";
+//! let file = File::open(filename).unwrap();
 //!
-//! let mut json = json::Reader::new(BufReader::new(file), Arc::new(schema), 1024, None);
+//! let mut json = json::Reader::new(Some(filename.to_string()), BufReader::new(file), Arc::new(schema), 1024, None);
 //! let batch = json.next().unwrap().unwrap();
 //! ```
 
+use std::collections::HashMap as CollectionsHashMap;
 use std::io::{BufRead, BufReader, Read, Seek, SeekFrom};
 use std::iter::FromIterator;
 use std::sync::Arc;
@@ -559,11 +561,12 @@ where
 /// use std::io::{BufReader, Seek, SeekFrom};
 /// use std::sync::Arc;
 ///
+/// let filename = "test/data/mixed_arrays.json";
 /// let mut reader =
-///     BufReader::new(File::open("test/data/mixed_arrays.json").unwrap());
+///     BufReader::new(File::open(filename).unwrap());
 /// let inferred_schema = infer_json_schema(&mut reader, None).unwrap();
 /// let batch_size = 1024;
-/// let decoder = Decoder::new(Arc::new(inferred_schema), batch_size, None);
+/// let decoder = Decoder::new(Arc::new(inferred_schema), batch_size, None, Some(filename.to_string()));
 ///
 /// // seek back to start so that the original file is usable again
 /// reader.seek(SeekFrom::Start(0)).unwrap();
@@ -580,6 +583,8 @@ pub struct Decoder {
     projection: Option<Vec<usize>>,
     /// Batch size (number of records to load each time)
     batch_size: usize,
+    /// Optional filename
+    filename: Option<String>,
 }
 
 impl Decoder {
@@ -589,11 +594,13 @@ impl Decoder {
         schema: SchemaRef,
         batch_size: usize,
         projection: Option<Vec<usize>>,
+        filename: Option<String>,
     ) -> Self {
         Self {
             schema,
             projection,
             batch_size,
+            filename,
         }
     }
 
@@ -661,7 +668,25 @@ impl Decoder {
 
         let projected_schema = Arc::new(Schema::new(projected_fields));
 
-        arrays.and_then(|arr| RecordBatch::try_new(projected_schema, arr).map(Some))
+        arrays.and_then(|arr| {
+            RecordBatch::try_new(projected_schema, arr).map(|batch| {
+                match self.filename.clone() {
+                    Some(filename) => {
+                        let mut metadata = CollectionsHashMap::new();
+                        metadata.insert("filename".to_string(), filename);
+                        let schema = Arc::new(Schema::new_with_metadata(
+                            batch.schema().fields().clone(),
+                            metadata,
+                        ));
+                        Some(
+                            RecordBatch::try_new(schema, batch.columns().to_vec())
+                                .unwrap(),
+                        )
+                    }
+                    None => Some(batch),
+                }
+            })
+        })
     }
 
     fn build_wrapped_list_array(
@@ -1422,18 +1447,26 @@ impl<R: Read> Reader<R> {
     /// If reading a `File` and you want to customise the Reader, such as to enable
     /// schema inference, use `ReaderBuilder`.
pub fn new( + filename: Option, reader: R, schema: SchemaRef, batch_size: usize, projection: Option>, ) -> Self { - Self::from_buf_reader(BufReader::new(reader), schema, batch_size, projection) + Self::from_buf_reader( + filename, + BufReader::new(reader), + schema, + batch_size, + projection, + ) } /// Create a new JSON Reader from a `BufReader` /// /// To customize the schema, such as to enable schema inference, use `ReaderBuilder` pub fn from_buf_reader( + filename: Option, reader: BufReader, schema: SchemaRef, batch_size: usize, @@ -1441,7 +1474,7 @@ impl Reader { ) -> Self { Self { reader, - decoder: Decoder::new(schema, batch_size, projection), + decoder: Decoder::new(schema, batch_size, projection, filename), } } @@ -1561,6 +1594,7 @@ impl ReaderBuilder { }; Ok(Reader::from_buf_reader( + None, buf_reader, schema, self.batch_size, @@ -1708,6 +1742,7 @@ mod tests { ]); let mut reader: Reader = Reader::new( + None, File::open("test/data/basic.json").unwrap(), Arc::new(schema.clone()), 1024, @@ -1760,6 +1795,7 @@ mod tests { ]); let mut reader: Reader = Reader::new( + None, File::open("test/data/basic.json").unwrap(), Arc::new(schema), 1024, @@ -1929,7 +1965,8 @@ mod tests { file.seek(SeekFrom::Start(0)).unwrap(); let reader = BufReader::new(GzDecoder::new(&file)); - let mut reader = Reader::from_buf_reader(reader, Arc::new(schema), 64, None); + let mut reader = + Reader::from_buf_reader(None, reader, Arc::new(schema), 64, None); let batch_gz = reader.next().unwrap().unwrap(); for batch in vec![batch, batch_gz] { @@ -2886,7 +2923,7 @@ mod tests { true, )]); - let decoder = Decoder::new(Arc::new(schema), 1024, None); + let decoder = Decoder::new(Arc::new(schema), 1024, None, None); let batch = decoder .next_batch( &mut vec![ @@ -2921,7 +2958,7 @@ mod tests { true, )]); - let decoder = Decoder::new(Arc::new(schema), 1024, None); + let decoder = Decoder::new(Arc::new(schema), 1024, None, None); let batch = decoder .next_batch( // NOTE: total struct element count needs to be greater than diff --git a/rust/datafusion-examples/examples/simple_udf.rs b/rust/datafusion-examples/examples/simple_udf.rs index bfef1089a634c..634fc42fa1f55 100644 --- a/rust/datafusion-examples/examples/simple_udf.rs +++ b/rust/datafusion-examples/examples/simple_udf.rs @@ -15,20 +15,20 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; + use arrow::{ array::{ArrayRef, Float32Array, Float64Array}, - datatypes::DataType, + datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, util::pretty, }; use datafusion::prelude::*; use datafusion::{error::Result, physical_plan::functions::make_scalar_function}; -use std::sync::Arc; // create local execution context with an in-memory table fn create_context() -> Result { - use arrow::datatypes::{Field, Schema}; use datafusion::datasource::MemTable; // define a schema. let schema = Arc::new(Schema::new(vec![ @@ -60,7 +60,7 @@ async fn main() -> Result<()> { let mut ctx = create_context()?; // First, declare the actual implementation of the calculation - let pow = |args: &[ArrayRef]| { + let pow = |args: &[ArrayRef], _: &Schema| { // in DataFusion, all `args` and output are dynamically-typed arrays, which means that we need to: // 1. cast the values to the type we want // 2. 
perform the computation for every element in the array (using a loop or SIMD) and construct the result diff --git a/rust/datafusion/README.md b/rust/datafusion/README.md index d0bb7d38892f2..d46ac76f9300c 100644 --- a/rust/datafusion/README.md +++ b/rust/datafusion/README.md @@ -172,11 +172,13 @@ DataFusion also includes a simple command-line interactive SQL utility. See the - [x] concat - [x] concat_ws - [x] initcap + - [x] input_file_name - [x] left - [x] length - [x] lpad - [x] ltrim - [x] octet_length + - [x] regexp_match - [x] regexp_replace - [x] repeat - [x] replace diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index 4c419d983a649..a94da698a513b 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -1881,7 +1881,7 @@ mod tests { ctx.register_table("t", table_with_sequence(1, 1).unwrap()) .unwrap(); - let myfunc = |args: &[ArrayRef]| Ok(Arc::clone(&args[0])); + let myfunc = |args: &[ArrayRef], _: &Schema| Ok(Arc::clone(&args[0])); let myfunc = make_scalar_function(myfunc); ctx.register_udf(create_udf( @@ -2181,7 +2181,7 @@ mod tests { let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]])?; ctx.register_table("t", Arc::new(provider))?; - let myfunc = |args: &[ArrayRef]| { + let myfunc = |args: &[ArrayRef], _: &Schema| { let l = &args[0] .as_any() .downcast_ref::() diff --git a/rust/datafusion/src/execution/dataframe_impl.rs b/rust/datafusion/src/execution/dataframe_impl.rs index 2a0c39aa48ebd..cfb692d96e457 100644 --- a/rust/datafusion/src/execution/dataframe_impl.rs +++ b/rust/datafusion/src/execution/dataframe_impl.rs @@ -182,7 +182,7 @@ mod tests { use crate::logical_plan::*; use crate::{datasource::csv::CsvReadOptions, physical_plan::ColumnarValue}; use crate::{physical_plan::functions::ScalarFunctionImplementation, test}; - use arrow::datatypes::DataType; + use arrow::datatypes::{DataType, Schema}; #[test] fn select_columns() -> Result<()> { @@ -304,7 +304,9 @@ mod tests { // declare the udf let my_fn: ScalarFunctionImplementation = - Arc::new(|_: &[ColumnarValue]| unimplemented!("my_fn is not implemented")); + Arc::new(|_: &[ColumnarValue], _: &Schema| { + unimplemented!("my_fn is not implemented") + }); // create and register the udf ctx.register_udf(create_udf( diff --git a/rust/datafusion/src/physical_plan/array_expressions.rs b/rust/datafusion/src/physical_plan/array_expressions.rs index a7e03b70e5d21..6c6ca035b9953 100644 --- a/rust/datafusion/src/physical_plan/array_expressions.rs +++ b/rust/datafusion/src/physical_plan/array_expressions.rs @@ -16,11 +16,11 @@ // under the License. //! Array expressions +use std::sync::Arc; use crate::error::{DataFusionError, Result}; use arrow::array::*; -use arrow::datatypes::DataType; -use std::sync::Arc; +use arrow::datatypes::{DataType, Schema}; use super::ColumnarValue; @@ -90,7 +90,7 @@ fn array_array(args: &[&dyn Array]) -> Result { } /// put values in an array. 
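+/// e.g. `array(1, 2, 3)` evaluates to the list value `[1, 2, 3]` for every input row.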
-pub fn array(values: &[ColumnarValue]) -> Result { +pub fn array(values: &[ColumnarValue], _: &Schema) -> Result { let arrays: Vec<&dyn Array> = values .iter() .map(|value| { diff --git a/rust/datafusion/src/physical_plan/crypto_expressions.rs b/rust/datafusion/src/physical_plan/crypto_expressions.rs index 8ad876b24d0ce..011c31790e045 100644 --- a/rust/datafusion/src/physical_plan/crypto_expressions.rs +++ b/rust/datafusion/src/physical_plan/crypto_expressions.rs @@ -30,7 +30,7 @@ use crate::{ }; use arrow::{ array::{Array, BinaryArray, GenericStringArray, StringOffsetSizeTrait}, - datatypes::DataType, + datatypes::{DataType, Schema}, }; use super::{string_expressions::unary_string_function, ColumnarValue}; @@ -144,7 +144,7 @@ fn md5_array( } /// crypto function that accepts Utf8 or LargeUtf8 and returns a [`ColumnarValue`] -pub fn md5(args: &[ColumnarValue]) -> Result { +pub fn md5(args: &[ColumnarValue], _: &Schema) -> Result { match &args[0] { ColumnarValue::Array(a) => match a.data_type() { DataType::Utf8 => Ok(ColumnarValue::Array(Arc::new(md5_array::(&[ @@ -178,21 +178,21 @@ pub fn md5(args: &[ColumnarValue]) -> Result { } /// crypto function that accepts Utf8 or LargeUtf8 and returns a [`ColumnarValue`] -pub fn sha224(args: &[ColumnarValue]) -> Result { +pub fn sha224(args: &[ColumnarValue], _: &Schema) -> Result { handle(args, sha_process::, "ssh224") } /// crypto function that accepts Utf8 or LargeUtf8 and returns a [`ColumnarValue`] -pub fn sha256(args: &[ColumnarValue]) -> Result { +pub fn sha256(args: &[ColumnarValue], _: &Schema) -> Result { handle(args, sha_process::, "sha256") } /// crypto function that accepts Utf8 or LargeUtf8 and returns a [`ColumnarValue`] -pub fn sha384(args: &[ColumnarValue]) -> Result { +pub fn sha384(args: &[ColumnarValue], _: &Schema) -> Result { handle(args, sha_process::, "sha384") } /// crypto function that accepts Utf8 or LargeUtf8 and returns a [`ColumnarValue`] -pub fn sha512(args: &[ColumnarValue]) -> Result { +pub fn sha512(args: &[ColumnarValue], _: &Schema) -> Result { handle(args, sha_process::, "sha512") } diff --git a/rust/datafusion/src/physical_plan/csv.rs b/rust/datafusion/src/physical_plan/csv.rs index 7ee5ae3fd90b0..63b947819c262 100644 --- a/rust/datafusion/src/physical_plan/csv.rs +++ b/rust/datafusion/src/physical_plan/csv.rs @@ -306,6 +306,7 @@ impl CsvStream { let bounds = limit.map(|x| (0, x + start_line)); let reader = csv::Reader::new( + Some(filename.to_string()), file, schema, has_header, diff --git a/rust/datafusion/src/physical_plan/datetime_expressions.rs b/rust/datafusion/src/physical_plan/datetime_expressions.rs index 3d363ce97d216..6b3ed17f29d2d 100644 --- a/rust/datafusion/src/physical_plan/datetime_expressions.rs +++ b/rust/datafusion/src/physical_plan/datetime_expressions.rs @@ -25,7 +25,7 @@ use crate::{ }; use arrow::{ array::{Array, ArrayRef, GenericStringArray, PrimitiveArray, StringOffsetSizeTrait}, - datatypes::{ArrowPrimitiveType, DataType, TimestampNanosecondType}, + datatypes::{ArrowPrimitiveType, DataType, Schema, TimestampNanosecondType}, }; use arrow::{ array::{ @@ -260,7 +260,7 @@ where } /// to_timestamp SQL function -pub fn to_timestamp(args: &[ColumnarValue]) -> Result { +pub fn to_timestamp(args: &[ColumnarValue], _: &Schema) -> Result { handle::( args, string_to_timestamp_nanos, @@ -308,7 +308,7 @@ fn date_trunc_single(granularity: &str, value: i64) -> Result { } /// date_trunc SQL function -pub fn date_trunc(args: &[ColumnarValue]) -> Result { +pub fn date_trunc(args: &[ColumnarValue], _: 
&Schema) -> Result { let (granularity, array) = (&args[0], &args[1]); let granularity = @@ -397,7 +397,7 @@ macro_rules! extract_date_part { } /// DATE_PART SQL function -pub fn date_part(args: &[ColumnarValue]) -> Result { +pub fn date_part(args: &[ColumnarValue], _: &Schema) -> Result { if args.len() != 2 { return Err(DataFusionError::Execution( "Expected two arguments in DATE_PART".to_string(), @@ -463,7 +463,7 @@ mod tests { let string_array = ColumnarValue::Array(Arc::new(string_builder.finish()) as ArrayRef); - let parsed_timestamps = to_timestamp(&[string_array]) + let parsed_timestamps = to_timestamp(&[string_array], &Schema::empty()) .expect("that to_timestamp parsed values without error"); if let ColumnarValue::Array(parsed_array) = parsed_timestamps { assert_eq!(parsed_array.len(), 2); @@ -543,7 +543,7 @@ mod tests { let expected_err = "Internal error: Unsupported data type Int64 for function to_timestamp"; - match to_timestamp(&[int64array]) { + match to_timestamp(&[int64array], &Schema::empty()) { Ok(_) => panic!("Expected error but got success"), Err(e) => { assert!( diff --git a/rust/datafusion/src/physical_plan/expressions/nullif.rs b/rust/datafusion/src/physical_plan/expressions/nullif.rs index 7cc58ed2318f4..4091f06e0a945 100644 --- a/rust/datafusion/src/physical_plan/expressions/nullif.rs +++ b/rust/datafusion/src/physical_plan/expressions/nullif.rs @@ -28,7 +28,7 @@ use arrow::array::{ }; use arrow::compute::kernels::boolean::nullif; use arrow::compute::kernels::comparison::{eq, eq_scalar, eq_utf8, eq_utf8_scalar}; -use arrow::datatypes::{DataType, TimeUnit}; +use arrow::datatypes::{DataType, Schema, TimeUnit}; /// Invoke a compute kernel on a primitive array and a Boolean Array macro_rules! compute_bool_array_op { @@ -71,7 +71,7 @@ macro_rules! primitive_bool_array_op { /// Args: 0 - left expr is any array /// 1 - if the left is equal to this expr2, then the result is NULL, otherwise left value is passed. 
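+/// e.g. `nullif(5, 5)` is NULL, while `nullif(5, 3)` is 5.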
/// -pub fn nullif_func(args: &[ColumnarValue]) -> Result { +pub fn nullif_func(args: &[ColumnarValue], _: &Schema) -> Result { if args.len() != 2 { return Err(DataFusionError::Internal(format!( "{:?} args were supplied but NULLIF takes exactly two args", @@ -142,7 +142,7 @@ mod tests { let lit_array = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))); - let result = nullif_func(&[a, lit_array])?; + let result = nullif_func(&[a, lit_array], &Schema::empty())?; let result = result.into_array(0); let expected = Arc::new(Int32Array::from(vec![ @@ -168,7 +168,7 @@ mod tests { let lit_array = ColumnarValue::Scalar(ScalarValue::Int32(Some(1i32))); - let result = nullif_func(&[a, lit_array])?; + let result = nullif_func(&[a, lit_array], &Schema::empty())?; let result = result.into_array(0); let expected = Arc::new(Int32Array::from(vec![ diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs index 56365fec1dc87..a8aa2186fd686 100644 --- a/rust/datafusion/src/physical_plan/functions.rs +++ b/rust/datafusion/src/physical_plan/functions.rs @@ -76,7 +76,7 @@ pub enum Signature { /// Scalar function pub type ScalarFunctionImplementation = - Arc Result + Send + Sync>; + Arc Result + Send + Sync>; /// A function's return type pub type ReturnTypeFunction = @@ -200,6 +200,8 @@ pub enum BuiltinScalarFunction { Upper, /// regexp_match RegexpMatch, + /// input_file_name + InputFileName, } impl fmt::Display for BuiltinScalarFunction { @@ -245,6 +247,7 @@ impl FromStr for BuiltinScalarFunction { "date_part" => BuiltinScalarFunction::DatePart, "date_trunc" => BuiltinScalarFunction::DateTrunc, "initcap" => BuiltinScalarFunction::InitCap, + "input_file_name" => BuiltinScalarFunction::InputFileName, "left" => BuiltinScalarFunction::Left, "length" => BuiltinScalarFunction::CharacterLength, "lower" => BuiltinScalarFunction::Lower, @@ -253,6 +256,7 @@ impl FromStr for BuiltinScalarFunction { "md5" => BuiltinScalarFunction::MD5, "nullif" => BuiltinScalarFunction::NullIf, "octet_length" => BuiltinScalarFunction::OctetLength, + "regexp_match" => BuiltinScalarFunction::RegexpMatch, "regexp_replace" => BuiltinScalarFunction::RegexpReplace, "repeat" => BuiltinScalarFunction::Repeat, "replace" => BuiltinScalarFunction::Replace, @@ -273,7 +277,6 @@ impl FromStr for BuiltinScalarFunction { "translate" => BuiltinScalarFunction::Translate, "trim" => BuiltinScalarFunction::Trim, "upper" => BuiltinScalarFunction::Upper, - "regexp_match" => BuiltinScalarFunction::RegexpMatch, _ => { return Err(DataFusionError::Plan(format!( "There is no built-in function named {}", @@ -295,15 +298,6 @@ pub fn return_type( // verify that this is a valid set of data types for this function data_types(&arg_types, &signature(fun))?; - if arg_types.is_empty() { - // functions currently cannot be evaluated without arguments, as they can't - // know the number of rows to return. - return Err(DataFusionError::Plan(format!( - "Function '{}' requires at least one argument", - fun - ))); - } - // the return type of the built in function. // Some built-in functions' return type depends on the incoming type. 
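+    // Zero-argument functions such as `input_file_name` are also accepted here:
+    // when a function is invoked with no arguments, `ScalarFunctionExpr::evaluate`
+    // substitutes a dummy null column sized to the batch, so the implementation
+    // still knows how many rows to produce.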
match fun { @@ -359,6 +353,7 @@ pub fn return_type( )); } }), + BuiltinScalarFunction::InputFileName => Ok(DataType::Utf8), BuiltinScalarFunction::Left => Ok(match arg_types[0] { DataType::LargeUtf8 => DataType::LargeUtf8, DataType::Utf8 => DataType::Utf8, @@ -733,19 +728,20 @@ pub fn create_physical_expr( // string functions BuiltinScalarFunction::Array => array_expressions::array, - BuiltinScalarFunction::Ascii => |args| match args[0].data_type() { + + BuiltinScalarFunction::Ascii => |args, input_schema| match args[0].data_type() { DataType::Utf8 => { - make_scalar_function(string_expressions::ascii::)(args) + make_scalar_function(string_expressions::ascii::)(args, input_schema) } DataType::LargeUtf8 => { - make_scalar_function(string_expressions::ascii::)(args) + make_scalar_function(string_expressions::ascii::)(args, input_schema) } other => Err(DataFusionError::Internal(format!( "Unsupported data type {:?} for function ascii", other, ))), }, - BuiltinScalarFunction::BitLength => |args| match &args[0] { + BuiltinScalarFunction::BitLength => |args, _| match &args[0] { ColumnarValue::Array(v) => Ok(ColumnarValue::Array(bit_length(v.as_ref())?)), ColumnarValue::Scalar(v) => match v { ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32( @@ -757,69 +753,73 @@ pub fn create_physical_expr( _ => unreachable!(), }, }, - BuiltinScalarFunction::Btrim => |args| match args[0].data_type() { + BuiltinScalarFunction::Btrim => |args, input_schema| match args[0].data_type() { DataType::Utf8 => { - make_scalar_function(string_expressions::btrim::)(args) + make_scalar_function(string_expressions::btrim::)(args, input_schema) } DataType::LargeUtf8 => { - make_scalar_function(string_expressions::btrim::)(args) + make_scalar_function(string_expressions::btrim::)(args, input_schema) } other => Err(DataFusionError::Internal(format!( "Unsupported data type {:?} for function btrim", other, ))), }, - BuiltinScalarFunction::CharacterLength => |args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - character_length, - Int32Type, - "character_length" - ); - make_scalar_function(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - character_length, - Int64Type, - "character_length" - ); - make_scalar_function(func)(args) + BuiltinScalarFunction::CharacterLength => { + |args, input_schema| match args[0].data_type() { + DataType::Utf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + character_length, + Int32Type, + "character_length" + ); + make_scalar_function(func)(args, input_schema) + } + DataType::LargeUtf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + character_length, + Int64Type, + "character_length" + ); + make_scalar_function(func)(args, input_schema) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function character_length", + other, + ))), } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function character_length", - other, - ))), - }, - BuiltinScalarFunction::Chr => { - |args| make_scalar_function(string_expressions::chr)(args) } + BuiltinScalarFunction::Chr => |args, input_schema| { + make_scalar_function(string_expressions::chr)(args, input_schema) + }, BuiltinScalarFunction::Concat => string_expressions::concat, - BuiltinScalarFunction::ConcatWithSeparator => { - |args| make_scalar_function(string_expressions::concat_ws)(args) - } + 
BuiltinScalarFunction::ConcatWithSeparator => |args, input_schema| { + make_scalar_function(string_expressions::concat_ws)(args, input_schema) + }, BuiltinScalarFunction::DatePart => datetime_expressions::date_part, BuiltinScalarFunction::DateTrunc => datetime_expressions::date_trunc, - BuiltinScalarFunction::InitCap => |args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function(string_expressions::initcap::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function(string_expressions::initcap::)(args) + BuiltinScalarFunction::InitCap => { + |args, input_schema| match args[0].data_type() { + DataType::Utf8 => make_scalar_function( + string_expressions::initcap::, + )(args, input_schema), + DataType::LargeUtf8 => make_scalar_function( + string_expressions::initcap::, + )(args, input_schema), + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function initcap", + other, + ))), } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function initcap", - other, - ))), - }, - BuiltinScalarFunction::Left => |args| match args[0].data_type() { + } + BuiltinScalarFunction::Left => |args, input_schema| match args[0].data_type() { DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!(left, i32, "left"); - make_scalar_function(func)(args) + make_scalar_function(func)(args, input_schema) } DataType::LargeUtf8 => { let func = invoke_if_unicode_expressions_feature_flag!(left, i64, "left"); - make_scalar_function(func)(args) + make_scalar_function(func)(args, input_schema) } other => Err(DataFusionError::Internal(format!( "Unsupported data type {:?} for function left", @@ -827,37 +827,40 @@ pub fn create_physical_expr( ))), }, BuiltinScalarFunction::Lower => string_expressions::lower, - BuiltinScalarFunction::Lpad => |args| match args[0].data_type() { + BuiltinScalarFunction::Lpad => |args, input_schema| match args[0].data_type() { DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!(lpad, i32, "lpad"); - make_scalar_function(func)(args) + make_scalar_function(func)(args, input_schema) } DataType::LargeUtf8 => { let func = invoke_if_unicode_expressions_feature_flag!(lpad, i64, "lpad"); - make_scalar_function(func)(args) + make_scalar_function(func)(args, input_schema) } other => Err(DataFusionError::Internal(format!( "Unsupported data type {:?} for function lpad", other, ))), }, - BuiltinScalarFunction::Ltrim => |args| match args[0].data_type() { + BuiltinScalarFunction::Ltrim => |args, input_schema| match args[0].data_type() { DataType::Utf8 => { - make_scalar_function(string_expressions::ltrim::)(args) + make_scalar_function(string_expressions::ltrim::)(args, input_schema) } DataType::LargeUtf8 => { - make_scalar_function(string_expressions::ltrim::)(args) + make_scalar_function(string_expressions::ltrim::)(args, input_schema) } other => Err(DataFusionError::Internal(format!( "Unsupported data type {:?} for function ltrim", other, ))), }, + BuiltinScalarFunction::InputFileName => |args, input_schema| { + make_scalar_function(string_expressions::input_file_name)(args, input_schema) + }, BuiltinScalarFunction::MD5 => { invoke_if_crypto_expressions_feature_flag!(md5, "md5") } BuiltinScalarFunction::NullIf => nullif_func, - BuiltinScalarFunction::OctetLength => |args| match &args[0] { + BuiltinScalarFunction::OctetLength => |args, _| match &args[0] { ColumnarValue::Array(v) => Ok(ColumnarValue::Array(length(v.as_ref())?)), ColumnarValue::Scalar(v) => match v { ScalarValue::Utf8(v) 
=> Ok(ColumnarValue::Scalar(ScalarValue::Int32( @@ -869,126 +872,137 @@ pub fn create_physical_expr( _ => unreachable!(), }, }, - BuiltinScalarFunction::RegexpMatch => |args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_regex_expressions_feature_flag!( - regexp_match, - i32, - "regexp_match" - ); - make_scalar_function(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_regex_expressions_feature_flag!( - regexp_match, - i64, - "regexp_match" - ); - make_scalar_function(func)(args) - } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function regexp_match", - other - ))), - }, - BuiltinScalarFunction::RegexpReplace => |args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_regex_expressions_feature_flag!( - regexp_replace, - i32, - "regexp_replace" - ); - make_scalar_function(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_regex_expressions_feature_flag!( - regexp_replace, - i64, - "regexp_replace" - ); - make_scalar_function(func)(args) - } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function regexp_replace", - other, - ))), - }, - BuiltinScalarFunction::Repeat => |args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function(string_expressions::repeat::)(args) + BuiltinScalarFunction::RegexpMatch => { + |args, input_schema| match args[0].data_type() { + DataType::Utf8 => { + let func = invoke_if_regex_expressions_feature_flag!( + regexp_match, + i32, + "regexp_match" + ); + make_scalar_function(func)(args, input_schema) + } + DataType::LargeUtf8 => { + let func = invoke_if_regex_expressions_feature_flag!( + regexp_match, + i64, + "regexp_match" + ); + make_scalar_function(func)(args, input_schema) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function regexp_match", + other + ))), } - DataType::LargeUtf8 => { - make_scalar_function(string_expressions::repeat::)(args) + } + BuiltinScalarFunction::RegexpReplace => { + |args, input_schema| match args[0].data_type() { + DataType::Utf8 => { + let func = invoke_if_regex_expressions_feature_flag!( + regexp_replace, + i32, + "regexp_replace" + ); + make_scalar_function(func)(args, input_schema) + } + DataType::LargeUtf8 => { + let func = invoke_if_regex_expressions_feature_flag!( + regexp_replace, + i64, + "regexp_replace" + ); + make_scalar_function(func)(args, input_schema) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function regexp_replace", + other, + ))), } + } + BuiltinScalarFunction::Repeat => |args, input_schema| match args[0].data_type() { + DataType::Utf8 => make_scalar_function(string_expressions::repeat::)( + args, + input_schema, + ), + DataType::LargeUtf8 => make_scalar_function( + string_expressions::repeat::, + )(args, input_schema), other => Err(DataFusionError::Internal(format!( "Unsupported data type {:?} for function repeat", other, ))), }, - BuiltinScalarFunction::Replace => |args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function(string_expressions::replace::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function(string_expressions::replace::)(args) - } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function replace", - other, - ))), - }, - BuiltinScalarFunction::Reverse => |args| match args[0].data_type() { - DataType::Utf8 => { - let func = - invoke_if_unicode_expressions_feature_flag!(reverse, i32, 
"reverse"); - make_scalar_function(func)(args) + BuiltinScalarFunction::Replace => { + |args, input_schema| match args[0].data_type() { + DataType::Utf8 => make_scalar_function( + string_expressions::replace::, + )(args, input_schema), + DataType::LargeUtf8 => make_scalar_function( + string_expressions::replace::, + )(args, input_schema), + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function replace", + other, + ))), } - DataType::LargeUtf8 => { - let func = - invoke_if_unicode_expressions_feature_flag!(reverse, i64, "reverse"); - make_scalar_function(func)(args) + } + BuiltinScalarFunction::Reverse => { + |args, input_schema| match args[0].data_type() { + DataType::Utf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + reverse, i32, "reverse" + ); + make_scalar_function(func)(args, input_schema) + } + DataType::LargeUtf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + reverse, i64, "reverse" + ); + make_scalar_function(func)(args, input_schema) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function reverse", + other, + ))), } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function reverse", - other, - ))), - }, - BuiltinScalarFunction::Right => |args| match args[0].data_type() { + } + BuiltinScalarFunction::Right => |args, input_schema| match args[0].data_type() { DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!(right, i32, "right"); - make_scalar_function(func)(args) + make_scalar_function(func)(args, input_schema) } DataType::LargeUtf8 => { let func = invoke_if_unicode_expressions_feature_flag!(right, i64, "right"); - make_scalar_function(func)(args) + make_scalar_function(func)(args, input_schema) } other => Err(DataFusionError::Internal(format!( "Unsupported data type {:?} for function right", other, ))), }, - BuiltinScalarFunction::Rpad => |args| match args[0].data_type() { + BuiltinScalarFunction::Rpad => |args, input_schema| match args[0].data_type() { DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!(rpad, i32, "rpad"); - make_scalar_function(func)(args) + make_scalar_function(func)(args, input_schema) } DataType::LargeUtf8 => { let func = invoke_if_unicode_expressions_feature_flag!(rpad, i64, "rpad"); - make_scalar_function(func)(args) + make_scalar_function(func)(args, input_schema) } other => Err(DataFusionError::Internal(format!( "Unsupported data type {:?} for function rpad", other, ))), }, - BuiltinScalarFunction::Rtrim => |args| match args[0].data_type() { + BuiltinScalarFunction::Rtrim => |args, input_schema| match args[0].data_type() { DataType::Utf8 => { - make_scalar_function(string_expressions::rtrim::)(args) + make_scalar_function(string_expressions::rtrim::)(args, input_schema) } DataType::LargeUtf8 => { - make_scalar_function(string_expressions::rtrim::)(args) + make_scalar_function(string_expressions::rtrim::)(args, input_schema) } other => Err(DataFusionError::Internal(format!( "Unsupported data type {:?} for function rtrim", @@ -1007,105 +1021,111 @@ pub fn create_physical_expr( BuiltinScalarFunction::SHA512 => { invoke_if_crypto_expressions_feature_flag!(sha512, "sha512") } - BuiltinScalarFunction::SplitPart => |args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function(string_expressions::split_part::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function(string_expressions::split_part::)(args) - } - other => 
Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function split_part", - other, - ))), - }, - BuiltinScalarFunction::StartsWith => |args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function(string_expressions::starts_with::)(args) + BuiltinScalarFunction::SplitPart => { + |args, input_schema| match args[0].data_type() { + DataType::Utf8 => make_scalar_function( + string_expressions::split_part::, + )(args, input_schema), + DataType::LargeUtf8 => make_scalar_function( + string_expressions::split_part::, + )(args, input_schema), + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function split_part", + other, + ))), } - DataType::LargeUtf8 => { - make_scalar_function(string_expressions::starts_with::)(args) + } + BuiltinScalarFunction::StartsWith => { + |args, input_schema| match args[0].data_type() { + DataType::Utf8 => make_scalar_function( + string_expressions::starts_with::, + )(args, input_schema), + DataType::LargeUtf8 => make_scalar_function( + string_expressions::starts_with::, + )(args, input_schema), + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function starts_with", + other, + ))), } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function starts_with", - other, - ))), - }, - BuiltinScalarFunction::Strpos => |args| match args[0].data_type() { + } + BuiltinScalarFunction::Strpos => |args, input_schema| match args[0].data_type() { DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!( strpos, Int32Type, "strpos" ); - make_scalar_function(func)(args) + make_scalar_function(func)(args, input_schema) } DataType::LargeUtf8 => { let func = invoke_if_unicode_expressions_feature_flag!( strpos, Int64Type, "strpos" ); - make_scalar_function(func)(args) + make_scalar_function(func)(args, input_schema) } other => Err(DataFusionError::Internal(format!( "Unsupported data type {:?} for function strpos", other, ))), }, - BuiltinScalarFunction::Substr => |args| match args[0].data_type() { + BuiltinScalarFunction::Substr => |args, input_schema| match args[0].data_type() { DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!(substr, i32, "substr"); - make_scalar_function(func)(args) + make_scalar_function(func)(args, input_schema) } DataType::LargeUtf8 => { let func = invoke_if_unicode_expressions_feature_flag!(substr, i64, "substr"); - make_scalar_function(func)(args) + make_scalar_function(func)(args, input_schema) } other => Err(DataFusionError::Internal(format!( "Unsupported data type {:?} for function substr", other, ))), }, - BuiltinScalarFunction::ToHex => |args| match args[0].data_type() { - DataType::Int32 => { - make_scalar_function(string_expressions::to_hex::)(args) - } - DataType::Int64 => { - make_scalar_function(string_expressions::to_hex::)(args) - } + BuiltinScalarFunction::ToHex => |args, input_schema| match args[0].data_type() { + DataType::Int32 => make_scalar_function( + string_expressions::to_hex::, + )(args, input_schema), + DataType::Int64 => make_scalar_function( + string_expressions::to_hex::, + )(args, input_schema), other => Err(DataFusionError::Internal(format!( "Unsupported data type {:?} for function to_hex", other, ))), }, BuiltinScalarFunction::ToTimestamp => datetime_expressions::to_timestamp, - BuiltinScalarFunction::Translate => |args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - translate, - i32, - 
"translate" - ); - make_scalar_function(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - translate, - i64, - "translate" - ); - make_scalar_function(func)(args) + BuiltinScalarFunction::Translate => { + |args, input_schema| match args[0].data_type() { + DataType::Utf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + translate, + i32, + "translate" + ); + make_scalar_function(func)(args, input_schema) + } + DataType::LargeUtf8 => { + let func = invoke_if_unicode_expressions_feature_flag!( + translate, + i64, + "translate" + ); + make_scalar_function(func)(args, input_schema) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function translate", + other, + ))), } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function translate", - other, - ))), - }, - BuiltinScalarFunction::Trim => |args| match args[0].data_type() { + } + BuiltinScalarFunction::Trim => |args, input_schema| match args[0].data_type() { DataType::Utf8 => { - make_scalar_function(string_expressions::btrim::)(args) + make_scalar_function(string_expressions::btrim::)(args, input_schema) } DataType::LargeUtf8 => { - make_scalar_function(string_expressions::btrim::)(args) + make_scalar_function(string_expressions::btrim::)(args, input_schema) } other => Err(DataFusionError::Internal(format!( "Unsupported data type {:?} for function trim", @@ -1273,6 +1293,7 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature { Signature::Exact(vec![DataType::Utf8, DataType::Utf8, DataType::Utf8]), Signature::Exact(vec![DataType::LargeUtf8, DataType::Utf8, DataType::Utf8]), ]), + BuiltinScalarFunction::InputFileName => Signature::Exact(vec![]), // math expressions expect 1 argument of type f64 or f32 // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we // return the best approximation for it (in f64). @@ -1375,9 +1396,18 @@ impl PhysicalExpr for ScalarFunctionExpr { .map(|e| e.evaluate(batch)) .collect::>>()?; + // if no arguments are passed create a dummy null column + let inputs = if inputs.is_empty() { + vec![ColumnarValue::Array( + ScalarValue::Utf8(None).to_array_of_size(batch.num_rows()), + )] + } else { + inputs + }; + // evaluate the function let fun = self.fun.as_ref(); - (fun)(&inputs) + fun(&inputs, &*batch.schema()) } } @@ -1385,9 +1415,9 @@ impl PhysicalExpr for ScalarFunctionExpr { /// and vice-versa after evaluation. pub fn make_scalar_function(inner: F) -> ScalarFunctionImplementation where - F: Fn(&[ArrayRef]) -> Result + Sync + Send + 'static, + F: Fn(&[ArrayRef], &Schema) -> Result + Sync + Send + 'static, { - Arc::new(move |args: &[ColumnarValue]| { + Arc::new(move |args: &[ColumnarValue], schema| { // first, identify if any of the arguments is an Array. If yes, store its `len`, // as any scalar will need to be converted to an array of len `len`. 
let len = args @@ -1408,7 +1438,7 @@ where .collect::>() }; - let result = (inner)(&args); + let result = (inner)(&args, schema); // maybe back to scalar if len.is_some() { @@ -3611,18 +3641,6 @@ mod tests { Ok(()) } - #[test] - fn test_concat_error() -> Result<()> { - let result = return_type(&BuiltinScalarFunction::Concat, &[]); - if result.is_ok() { - Err(DataFusionError::Plan( - "Function 'concat' cannot accept zero arguments".to_string(), - )) - } else { - Ok(()) - } - } - fn generic_test_array( value1: ArrayRef, value2: ArrayRef, diff --git a/rust/datafusion/src/physical_plan/math_expressions.rs b/rust/datafusion/src/physical_plan/math_expressions.rs index 382a15f8ccf6e..eb91cdb64910a 100644 --- a/rust/datafusion/src/physical_plan/math_expressions.rs +++ b/rust/datafusion/src/physical_plan/math_expressions.rs @@ -19,7 +19,7 @@ use arrow::array::{make_array, Array, ArrayData, Float32Array, Float64Array}; use arrow::buffer::Buffer; -use arrow::datatypes::{DataType, ToByteSlice}; +use arrow::datatypes::{DataType, Schema, ToByteSlice}; use super::{ColumnarValue, ScalarValue}; use crate::error::{DataFusionError, Result}; @@ -93,7 +93,7 @@ macro_rules! unary_primitive_array_op { macro_rules! math_unary_function { ($NAME:expr, $FUNC:ident) => { /// mathematical function that accepts f32 or f64 and returns f64 - pub fn $FUNC(args: &[ColumnarValue]) -> Result { + pub fn $FUNC(args: &[ColumnarValue], _: &Schema) -> Result { unary_primitive_array_op!(&args[0], $NAME, $FUNC) } }; diff --git a/rust/datafusion/src/physical_plan/parquet.rs b/rust/datafusion/src/physical_plan/parquet.rs index fce85e3607438..b192dcb3fa122 100644 --- a/rust/datafusion/src/physical_plan/parquet.rs +++ b/rust/datafusion/src/physical_plan/parquet.rs @@ -914,9 +914,16 @@ fn read_files( loop { match batch_reader.next() { Some(Ok(batch)) => { - //println!("ParquetExec got new batch from {}", filename); + let mut metadata = HashMap::new(); + metadata.insert("filename".to_string(), filename.to_string()); + let schema = Arc::new(Schema::new_with_metadata( + batch.schema().fields().clone(), + metadata, + )); + let metadata_batch = + RecordBatch::try_new(schema, batch.columns().to_vec()).unwrap(); total_rows += batch.num_rows(); - send_result(&response_tx, Ok(batch))?; + send_result(&response_tx, Ok(metadata_batch))?; if limit.map(|l| total_rows >= l).unwrap_or(false) { break 'outer; } diff --git a/rust/datafusion/src/physical_plan/regex_expressions.rs b/rust/datafusion/src/physical_plan/regex_expressions.rs index b526e7259ef61..fe5fc84fe3f93 100644 --- a/rust/datafusion/src/physical_plan/regex_expressions.rs +++ b/rust/datafusion/src/physical_plan/regex_expressions.rs @@ -27,6 +27,7 @@ use std::sync::Arc; use crate::error::{DataFusionError, Result}; use arrow::array::{ArrayRef, GenericStringArray, StringOffsetSizeTrait}; use arrow::compute; +use arrow::datatypes::Schema; use hashbrown::HashMap; use regex::Regex; @@ -45,7 +46,10 @@ macro_rules! downcast_string_arg { } /// extract a specific group from a string column, using a regular expression -pub fn regexp_match(args: &[ArrayRef]) -> Result { +pub fn regexp_match( + args: &[ArrayRef], + _: &Schema, +) -> Result { match args.len() { 2 => compute::regexp_match(downcast_string_arg!(args[0], "string", T), downcast_string_arg!(args[1], "pattern", T), None) .map_err(DataFusionError::ArrowError), @@ -72,7 +76,10 @@ fn regex_replace_posix_groups(replacement: &str) -> String { /// Replaces substring(s) matching a POSIX regular expression. 
/// /// example: `regexp_replace('Thomas', '.[mN]a.', 'M') = 'ThM'` -pub fn regexp_replace(args: &[ArrayRef]) -> Result { +pub fn regexp_replace( + args: &[ArrayRef], + _: &Schema, +) -> Result { // creating Regex is expensive so create hashmap for memoization let mut patterns: HashMap = HashMap::new(); diff --git a/rust/datafusion/src/physical_plan/string_expressions.rs b/rust/datafusion/src/physical_plan/string_expressions.rs index 882fe30502fdf..9334df4124b1d 100644 --- a/rust/datafusion/src/physical_plan/string_expressions.rs +++ b/rust/datafusion/src/physical_plan/string_expressions.rs @@ -33,7 +33,7 @@ use arrow::{ Array, ArrayRef, BooleanArray, GenericStringArray, Int32Array, Int64Array, PrimitiveArray, StringArray, StringOffsetSizeTrait, }, - datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType}, + datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType, Schema}, }; use super::ColumnarValue; @@ -174,7 +174,10 @@ where /// Returns the numeric code of the first character of the argument. /// ascii('x') = 120 -pub fn ascii(args: &[ArrayRef]) -> Result { +pub fn ascii( + args: &[ArrayRef], + _: &Schema, +) -> Result { let string_array = downcast_string_arg!(args[0], "string", T); let result = string_array @@ -192,7 +195,10 @@ pub fn ascii(args: &[ArrayRef]) -> Result { /// Removes the longest string containing only characters in characters (a space by default) from the start and end of string. /// btrim('xyxtrimyyx', 'xyz') = 'trim' -pub fn btrim(args: &[ArrayRef]) -> Result { +pub fn btrim( + args: &[ArrayRef], + _: &Schema, +) -> Result { match args.len() { 1 => { let string_array = downcast_string_arg!(args[0], "string", T); @@ -240,7 +246,7 @@ pub fn btrim(args: &[ArrayRef]) -> Result { /// Returns the character with the given code. chr(0) is disallowed because text data types cannot store that character. /// chr(65) = 'A' -pub fn chr(args: &[ArrayRef]) -> Result { +pub fn chr(args: &[ArrayRef], _: &Schema) -> Result { let integer_array = downcast_arg!(args[0], "integer", Int64Array); // first map is the iterator, second is for the `Option<_>` @@ -271,7 +277,7 @@ pub fn chr(args: &[ArrayRef]) -> Result { /// Concatenates the text representations of all the arguments. NULL arguments are ignored. /// concat('abcde', 2, NULL, 22) = 'abcde222' -pub fn concat(args: &[ColumnarValue]) -> Result { +pub fn concat(args: &[ColumnarValue], _: &Schema) -> Result { // do not accept 0 arguments. if args.is_empty() { return Err(DataFusionError::Internal(format!( @@ -331,7 +337,7 @@ pub fn concat(args: &[ColumnarValue]) -> Result { /// Concatenates all but the first argument, with separators. The first argument is used as the separator string, and should not be NULL. Other NULL arguments are ignored. /// concat_ws(',', 'abcde', 2, NULL, 22) = 'abcde,2,22' -pub fn concat_ws(args: &[ArrayRef]) -> Result { +pub fn concat_ws(args: &[ArrayRef], _: &Schema) -> Result { // downcast all arguments to strings let args = downcast_vec!(args, StringArray).collect::>>()?; @@ -370,7 +376,10 @@ pub fn concat_ws(args: &[ArrayRef]) -> Result { /// Converts the first letter of each word to upper case and the rest to lower case. Words are sequences of alphanumeric characters separated by non-alphanumeric characters. 
/// initcap('hi THOMAS') = 'Hi Thomas'
-pub fn initcap<T: StringOffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
+pub fn initcap<T: StringOffsetSizeTrait>(
+    args: &[ArrayRef],
+    _: &Schema,
+) -> Result<ArrayRef> {
     let string_array = downcast_string_arg!(args[0], "string", T);
 
     // first map is the iterator, second is for the `Option<_>`
@@ -400,13 +409,16 @@
 
 /// Converts the string to all lower case.
 /// lower('TOM') = 'tom'
-pub fn lower(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+pub fn lower(args: &[ColumnarValue], _: &Schema) -> Result<ColumnarValue> {
     handle(args, |string| string.to_ascii_lowercase(), "lower")
 }
 
 /// Removes the longest string containing only characters in characters (a space by default) from the start of string.
 /// ltrim('zzzytest', 'xyz') = 'test'
-pub fn ltrim<T: StringOffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
+pub fn ltrim<T: StringOffsetSizeTrait>(
+    args: &[ArrayRef],
+    _: &Schema,
+) -> Result<ArrayRef> {
     match args.len() {
         1 => {
             let string_array = downcast_string_arg!(args[0], "string", T);
@@ -445,7 +457,10 @@
 
 /// Repeats string the specified number of times.
 /// repeat('Pg', 4) = 'PgPgPgPg'
-pub fn repeat<T: StringOffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
+pub fn repeat<T: StringOffsetSizeTrait>(
+    args: &[ArrayRef],
+    _: &Schema,
+) -> Result<ArrayRef> {
     let string_array = downcast_string_arg!(args[0], "string", T);
     let number_array = downcast_arg!(args[1], "number", Int64Array);
 
@@ -463,7 +478,10 @@
 
 /// Replaces all occurrences in string of substring from with substring to.
 /// replace('abcdefabcdef', 'cd', 'XX') = 'abXXefabXXef'
-pub fn replace<T: StringOffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
+pub fn replace<T: StringOffsetSizeTrait>(
+    args: &[ArrayRef],
+    _: &Schema,
+) -> Result<ArrayRef> {
     let string_array = downcast_string_arg!(args[0], "string", T);
     let from_array = downcast_string_arg!(args[1], "from", T);
     let to_array = downcast_string_arg!(args[2], "to", T);
@@ -481,9 +499,23 @@
     Ok(Arc::new(result) as ArrayRef)
 }
 
+/// Returns the name of the file being read or null if not available.
+/// input_file_name() = './alltypes_plain.parquet'
+pub fn input_file_name(args: &[ArrayRef], schema: &Schema) -> Result<ArrayRef> {
+    let result = (0..args[0].len())
+        .into_iter()
+        .map(|_| schema.metadata().get("filename"))
+        .collect::<StringArray>();
+
+    Ok(Arc::new(result) as ArrayRef)
+}
+
 /// Removes the longest string containing only characters in characters (a space by default) from the end of string.
 /// rtrim('testxxzx', 'xyz') = 'test'
-pub fn rtrim<T: StringOffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
+pub fn rtrim<T: StringOffsetSizeTrait>(
+    args: &[ArrayRef],
+    _: &Schema,
+) -> Result<ArrayRef> {
     match args.len() {
         1 => {
             let string_array = downcast_string_arg!(args[0], "string", T);
@@ -522,7 +554,10 @@
 
 /// Splits string at occurrences of delimiter and returns the n'th field (counting from one).
/// split_part('abc~@~def~@~ghi', '~@~', 2) = 'def' -pub fn split_part(args: &[ArrayRef]) -> Result { +pub fn split_part( + args: &[ArrayRef], + _: &Schema, +) -> Result { let string_array = downcast_string_arg!(args[0], "string", T); let delimiter_array = downcast_string_arg!(args[1], "delimiter", T); let n_array = downcast_arg!(args[2], "n", Int64Array); @@ -554,7 +589,10 @@ pub fn split_part(args: &[ArrayRef]) -> Result(args: &[ArrayRef]) -> Result { +pub fn starts_with( + args: &[ArrayRef], + _: &Schema, +) -> Result { let string_array = downcast_string_arg!(args[0], "string", T); let prefix_array = downcast_string_arg!(args[1], "prefix", T); @@ -572,7 +610,7 @@ pub fn starts_with(args: &[ArrayRef]) -> Result(args: &[ArrayRef]) -> Result +pub fn to_hex(args: &[ArrayRef], _: &Schema) -> Result where T::Native: StringOffsetSizeTrait, { @@ -590,6 +628,6 @@ where /// Converts the string to all upper case. /// upper('tom') = 'TOM' -pub fn upper(args: &[ColumnarValue]) -> Result { +pub fn upper(args: &[ColumnarValue], _: &Schema) -> Result { handle(args, |string| string.to_ascii_uppercase(), "upper") } diff --git a/rust/datafusion/src/physical_plan/unicode_expressions.rs b/rust/datafusion/src/physical_plan/unicode_expressions.rs index 787ea7ea26730..18f17ebf1bc5c 100644 --- a/rust/datafusion/src/physical_plan/unicode_expressions.rs +++ b/rust/datafusion/src/physical_plan/unicode_expressions.rs @@ -30,7 +30,7 @@ use arrow::{ array::{ ArrayRef, GenericStringArray, Int64Array, PrimitiveArray, StringOffsetSizeTrait, }, - datatypes::{ArrowNativeType, ArrowPrimitiveType}, + datatypes::{ArrowNativeType, ArrowPrimitiveType, Schema}, }; use hashbrown::HashMap; use unicode_segmentation::UnicodeSegmentation; @@ -63,7 +63,10 @@ macro_rules! downcast_arg { /// Returns number of characters in the string. /// character_length('josé') = 4 -pub fn character_length(args: &[ArrayRef]) -> Result +pub fn character_length( + args: &[ArrayRef], + _: &Schema, +) -> Result where T::Native: StringOffsetSizeTrait, { @@ -90,7 +93,7 @@ where /// Returns first n characters in the string, or when n is negative, returns all but last |n| characters. /// left('abcde', 2) = 'ab' -pub fn left(args: &[ArrayRef]) -> Result { +pub fn left(args: &[ArrayRef], _: &Schema) -> Result { let string_array = downcast_string_arg!(args[0], "string", T); let n_array = downcast_arg!(args[1], "n", Int64Array); @@ -124,7 +127,7 @@ pub fn left(args: &[ArrayRef]) -> Result { /// Extends the string to length 'length' by prepending the characters fill (a space by default). If the string is already longer than length then it is truncated (on the right). /// lpad('hi', 5, 'xy') = 'xyxhi' -pub fn lpad(args: &[ArrayRef]) -> Result { +pub fn lpad(args: &[ArrayRef], _: &Schema) -> Result { match args.len() { 2 => { let string_array = downcast_string_arg!(args[0], "string", T); @@ -213,7 +216,10 @@ pub fn lpad(args: &[ArrayRef]) -> Result { /// Reverses the order of the characters in the string. /// reverse('abcde') = 'edcba' -pub fn reverse(args: &[ArrayRef]) -> Result { +pub fn reverse( + args: &[ArrayRef], + _: &Schema, +) -> Result { let string_array = downcast_string_arg!(args[0], "string", T); let result = string_array @@ -228,7 +234,10 @@ pub fn reverse(args: &[ArrayRef]) -> Result /// Returns last n characters in the string, or when n is negative, returns all but first |n| characters. 
/// right('abcde', 2) = 'de' -pub fn right(args: &[ArrayRef]) -> Result { +pub fn right( + args: &[ArrayRef], + _: &Schema, +) -> Result { let string_array = downcast_string_arg!(args[0], "string", T); let n_array = downcast_arg!(args[1], "n", Int64Array); @@ -276,7 +285,7 @@ pub fn right(args: &[ArrayRef]) -> Result { /// Extends the string to length 'length' by appending the characters fill (a space by default). If the string is already longer than length then it is truncated. /// rpad('hi', 5, 'xy') = 'hixyx' -pub fn rpad(args: &[ArrayRef]) -> Result { +pub fn rpad(args: &[ArrayRef], _: &Schema) -> Result { match args.len() { 2 => { let string_array = downcast_string_arg!(args[0], "string", T); @@ -353,7 +362,7 @@ pub fn rpad(args: &[ArrayRef]) -> Result { /// Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.) /// strpos('high', 'ig') = 2 -pub fn strpos(args: &[ArrayRef]) -> Result +pub fn strpos(args: &[ArrayRef], _: &Schema) -> Result where T::Native: StringOffsetSizeTrait, { @@ -412,7 +421,10 @@ where /// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).) /// substr('alphabet', 3) = 'phabet' /// substr('alphabet', 3, 2) = 'ph' -pub fn substr(args: &[ArrayRef]) -> Result { +pub fn substr( + args: &[ArrayRef], + _: &Schema, +) -> Result { match args.len() { 2 => { let string_array = downcast_string_arg!(args[0], "string", T); @@ -489,7 +501,10 @@ pub fn substr(args: &[ArrayRef]) -> Result { /// Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted. 
/// translate('12345', '143', 'ax') = 'a2x5'
-pub fn translate<T: StringOffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
+pub fn translate<T: StringOffsetSizeTrait>(
+    args: &[ArrayRef],
+    _: &Schema,
+) -> Result<ArrayRef> {
     let string_array = downcast_string_arg!(args[0], "string", T);
     let from_array = downcast_string_arg!(args[1], "from", T);
     let to_array = downcast_string_arg!(args[2], "to", T);
diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs
index f3ea7c9e34d84..b809ee616e512 100644
--- a/rust/datafusion/src/sql/planner.rs
+++ b/rust/datafusion/src/sql/planner.rs
@@ -2704,7 +2704,7 @@ mod tests {
     fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
         let f: ScalarFunctionImplementation =
-            Arc::new(|_| Err(DataFusionError::NotImplemented("".to_string())));
+            Arc::new(|_, _| Err(DataFusionError::NotImplemented("".to_string())));
         match name {
             "my_sqrt" => Some(Arc::new(create_udf(
                 "my_sqrt",
diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs
index e92bf5593f3d3..790eca68f8570 100644
--- a/rust/datafusion/tests/sql.rs
+++ b/rust/datafusion/tests/sql.rs
@@ -120,6 +120,31 @@ async fn parquet_query() {
     assert_eq!(expected, actual);
 }
 
+#[tokio::test]
+async fn parquet_input_file_name() {
+    let mut ctx = ExecutionContext::new();
+    register_alltypes_parquet(&mut ctx);
+    // NOTE: the ids below follow the physical row order of alltypes_plain.parquet
+    let sql = "SELECT id, input_file_name() FROM alltypes_plain";
+    let actual = execute(&mut ctx, sql).await;
+    let expected = vec![
+        vec!["4"],
+        vec!["5"],
+        vec!["6"],
+        vec!["7"],
+        vec!["2"],
+        vec!["3"],
+        vec!["0"],
+        vec!["1"],
+    ];
+
+    for i in 0..actual.len() {
+        assert_eq!(actual[i][0], expected[i][0]);
+        assert!(actual[i][1].contains("alltypes_plain.parquet"));
+    }
+}
+
 #[tokio::test]
 async fn parquet_single_nan_schema() {
     let mut ctx = ExecutionContext::new();
@@ -275,6 +300,31 @@ async fn csv_count_star() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn csv_input_file_name() -> Result<()> {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx)?;
+    let sql = "SELECT c1, input_file_name() FROM aggregate_test_100";
+    let actual = execute(&mut ctx, sql).await;
+    let expected = vec![
+        vec!["c"],
+        vec!["d"],
+        vec!["b"],
+        vec!["a"],
+        vec!["b"],
+        vec!["b"],
+        vec!["e"],
+        vec!["a"],
+        vec!["d"],
+        vec!["a"],
+    ];
+    for i in 0..expected.len() {
+        assert_eq!(actual[i][0], expected[i][0]);
+        assert!(actual[i][1].contains("aggregate_test_100.csv"));
+    }
+    Ok(())
+}
+
 #[tokio::test]
 async fn csv_query_with_predicate() -> Result<()> {
     let mut ctx = ExecutionContext::new();
@@ -575,7 +625,7 @@ fn create_ctx() -> Result<ExecutionContext> {
     Ok(ctx)
 }
 
-fn custom_sqrt(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+fn custom_sqrt(args: &[ColumnarValue], _: &Schema) -> Result<ColumnarValue> {
     let arg = &args[0];
     if let ColumnarValue::Array(v) = arg {
         let input = v
@@ -2670,7 +2720,7 @@ async fn test_cast_expressions() -> Result<()> {
 #[tokio::test]
 async fn test_cast_expressions_error() -> Result<()> {
     // cast(utf8 as int) over non-numeric data should error
-    let mut ctx = create_ctx()?;
+    let mut ctx = ExecutionContext::new();
     register_aggregate_csv(&mut ctx)?;
     let sql = "SELECT CAST(c1 AS INT) FROM aggregate_test_100";
     let plan = ctx.create_logical_plan(&sql).unwrap();
@@ -2686,6 +2736,18 @@ async fn test_cast_expressions_error() -> Result<()> {
             ))
         }
     }
-
     Ok(())
 }
+
+#[tokio::test]
+async fn input_file_name() -> Result<()> {
+    let mut ctx = ExecutionContext::new();
+    let sql = "SELECT
+        input_file_name()
+    ";
+    let actual = execute(&mut ctx, sql).await;
+
+    let expected = vec![vec!["NULL"]];
+    assert_eq!(expected, actual);
+    Ok(())
+}
\ No newline at end of file