fix: Optimize read_side_padding (apache#772)
## Which issue does this PR close?

## Rationale for this change

This PR improves the performance of `read_side_padding`, the scalar function used to pad values read from CHAR(n) columns.

## What changes are included in this PR?

Optimized `spark_read_side_padding` (formerly `spark_rpad`): per-row `String` allocations and grapheme segmentation are replaced with a single pre-sized `GenericStringBuilder` and `char`-based length counting.

## How are these changes tested?

Added tests

(cherry picked from commit 457d9d1)
kazuyukitanimura authored and huaxingao committed Aug 22, 2024
1 parent 8f16a4e commit 1886b57
Showing 9 changed files with 71 additions and 44 deletions.
1 change: 0 additions & 1 deletion native/Cargo.lock

Some generated files are not rendered by default.

@@ -21,8 +21,8 @@ use datafusion_comet_spark_expr::scalar_funcs::hash_expressions::{
 };
 use datafusion_comet_spark_expr::scalar_funcs::{
     spark_ceil, spark_decimal_div, spark_floor, spark_hex, spark_isnan, spark_make_decimal,
-    spark_murmur3_hash, spark_round, spark_rpad, spark_unhex, spark_unscaled_value, spark_xxhash64,
-    SparkChrFunc,
+    spark_murmur3_hash, spark_read_side_padding, spark_round, spark_unhex, spark_unscaled_value,
+    spark_xxhash64, SparkChrFunc,
 };
 use datafusion_common::{DataFusionError, Result as DataFusionResult};
 use datafusion_expr::registry::FunctionRegistry;
@@ -67,9 +67,9 @@ pub fn create_comet_physical_fun(
         "floor" => {
             make_comet_scalar_udf!("floor", spark_floor, data_type)
         }
-        "rpad" => {
-            let func = Arc::new(spark_rpad);
-            make_comet_scalar_udf!("rpad", func, without data_type)
+        "read_side_padding" => {
+            let func = Arc::new(spark_read_side_padding);
+            make_comet_scalar_udf!("read_side_padding", func, without data_type)
         }
         "round" => {
             make_comet_scalar_udf!("round", spark_round, data_type)
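For context, the registered function follows DataFusion's scalar-UDF calling convention: it takes a slice of `ColumnarValue`s and returns one. Below is a minimal sketch of invoking it directly, assuming the public export shown in the import hunk above; the sample data and the `main` wrapper are illustrative, not part of this PR:

```rust
use std::sync::Arc;

use arrow_array::{Array, ArrayRef, StringArray};
use datafusion::physical_plan::ColumnarValue;
use datafusion_comet_spark_expr::scalar_funcs::spark_read_side_padding;
use datafusion_common::ScalarValue;

fn main() -> datafusion_common::Result<()> {
    let strings: ArrayRef = Arc::new(StringArray::from(vec![Some("ab"), Some(""), None]));
    // Second argument is the CHAR(n) width, passed as an Int32 scalar.
    let args = [
        ColumnarValue::Array(strings),
        ColumnarValue::Scalar(ScalarValue::Int32(Some(4))),
    ];
    // Non-null values shorter than 4 characters are right-padded with spaces.
    if let ColumnarValue::Array(padded) = spark_read_side_padding(&args)? {
        let padded = padded.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(padded.value(0), "ab  ");
        assert_eq!(padded.value(1), "    ");
        assert!(padded.is_null(2));
    }
    Ok(())
}
```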
15 changes: 10 additions & 5 deletions native/core/src/execution/datafusion/planner.rs
@@ -1724,11 +1724,16 @@ impl PhysicalPlanner {
 
         let data_type = match expr.return_type.as_ref().map(to_arrow_datatype) {
             Some(t) => t,
-            None => self
-                .session_ctx
-                .udf(fun_name)?
-                .inner()
-                .return_type(&input_expr_types)?,
+            None => {
+                let fun_name = match fun_name.as_str() {
+                    "read_side_padding" => "rpad", // use the same return type as rpad
+                    other => other,
+                };
+                self.session_ctx
+                    .udf(fun_name)?
+                    .inner()
+                    .return_type(&input_expr_types)?
+            }
         };
 
         let fun_expr =
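Since `read_side_padding` is Comet's own name, it has no entry in DataFusion's function registry; the alias lets the planner reuse `rpad`'s return-type rule instead. A hedged sketch of that lookup, mirroring the calls in the diff and assuming a stock `SessionContext` with DataFusion's default functions registered:

```rust
use arrow_schema::DataType;
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // `rpad` ships with DataFusion; `read_side_padding` does not, hence the rename above.
    let rpad = ctx.udf("rpad")?;
    // Same resolution path as the planner: registry lookup, then return_type().
    let return_type = rpad.inner().return_type(&[DataType::Utf8, DataType::Int32])?;
    assert_eq!(return_type, DataType::Utf8);
    Ok(())
}
```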
1 change: 0 additions & 1 deletion native/spark-expr/Cargo.toml
@@ -41,7 +41,6 @@ chrono-tz = { workspace = true }
 num = { workspace = true }
 regex = { workspace = true }
 thiserror = { workspace = true }
-unicode-segmentation = "1.11.0"
 
 [dev-dependencies]
 arrow-data = {workspace = true}
62 changes: 32 additions & 30 deletions native/spark-expr/src/scalar_funcs.rs
@@ -15,15 +15,14 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::{cmp::min, sync::Arc};
-
 use arrow::{
     array::{
-        ArrayRef, AsArray, Decimal128Builder, Float32Array, Float64Array, GenericStringArray,
-        Int16Array, Int32Array, Int64Array, Int64Builder, Int8Array, OffsetSizeTrait,
+        ArrayRef, AsArray, Decimal128Builder, Float32Array, Float64Array, Int16Array, Int32Array,
+        Int64Array, Int64Builder, Int8Array, OffsetSizeTrait,
     },
     datatypes::{validate_decimal_precision, Decimal128Type, Int64Type},
 };
+use arrow_array::builder::GenericStringBuilder;
 use arrow_array::{Array, ArrowNativeTypeOp, BooleanArray, Decimal128Array};
 use arrow_schema::{DataType, DECIMAL128_MAX_PRECISION};
 use datafusion::{functions::math::round::round, physical_plan::ColumnarValue};
@@ -35,7 +34,8 @@ use num::{
     integer::{div_ceil, div_floor},
     BigInt, Signed, ToPrimitive,
 };
-use unicode_segmentation::UnicodeSegmentation;
+use std::fmt::Write;
+use std::{cmp::min, sync::Arc};
 
 mod unhex;
 pub use unhex::spark_unhex;
@@ -387,52 +387,54 @@ pub fn spark_round(
 }
 
 /// Similar to DataFusion `rpad`, but not to truncate when the string is already longer than length
-pub fn spark_rpad(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
+pub fn spark_read_side_padding(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
     match args {
         [ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(length)))] => {
-            match args[0].data_type() {
-                DataType::Utf8 => spark_rpad_internal::<i32>(array, *length),
-                DataType::LargeUtf8 => spark_rpad_internal::<i64>(array, *length),
+            match array.data_type() {
+                DataType::Utf8 => spark_read_side_padding_internal::<i32>(array, *length),
+                DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(array, *length),
                 // TODO: handle Dictionary types
                 other => Err(DataFusionError::Internal(format!(
-                    "Unsupported data type {other:?} for function rpad",
+                    "Unsupported data type {other:?} for function read_side_padding",
                 ))),
             }
         }
         other => Err(DataFusionError::Internal(format!(
-            "Unsupported arguments {other:?} for function rpad",
+            "Unsupported arguments {other:?} for function read_side_padding",
         ))),
     }
 }
 
-fn spark_rpad_internal<T: OffsetSizeTrait>(
+fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
     array: &ArrayRef,
     length: i32,
 ) -> Result<ColumnarValue, DataFusionError> {
     let string_array = as_generic_string_array::<T>(array)?;
+    let length = 0.max(length) as usize;
+    let space_string = " ".repeat(length);
+
+    let mut builder =
+        GenericStringBuilder::<T>::with_capacity(string_array.len(), string_array.len() * length);
 
-    let result = string_array
-        .iter()
-        .map(|string| match string {
+    for string in string_array.iter() {
+        match string {
             Some(string) => {
-                let length = if length < 0 { 0 } else { length as usize };
-                if length == 0 {
-                    Ok(Some("".to_string()))
+                // It looks Spark's UTF8String is closer to chars rather than graphemes
+                // https://stackoverflow.com/a/46290728
+                let char_len = string.chars().count();
+                if length <= char_len {
+                    builder.append_value(string);
                 } else {
-                    let graphemes = string.graphemes(true).collect::<Vec<&str>>();
-                    if length < graphemes.len() {
-                        Ok(Some(string.to_string()))
-                    } else {
-                        let mut s = string.to_string();
-                        s.push_str(" ".repeat(length - graphemes.len()).as_str());
-                        Ok(Some(s))
-                    }
+                    // write_str updates only the value buffer, not null nor offset buffer
+                    // This is convenient for concatenating str(s)
+                    builder.write_str(string)?;
+                    builder.append_value(&space_string[char_len..]);
                 }
             }
-            _ => Ok(None),
-        })
-        .collect::<Result<GenericStringArray<T>, DataFusionError>>()?;
-    Ok(ColumnarValue::Array(Arc::new(result)))
+            _ => builder.append_null(),
+        }
+    }
+    Ok(ColumnarValue::Array(Arc::new(builder.finish())))
 }
 
 // Let Decimal(p3, s3) as return type i.e. Decimal(p1, s1) / Decimal(p2, s2) = Decimal(p3, s3).
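The gains come from two changes visible above: lengths are counted in `char`s instead of grapheme clusters (dropping the `unicode-segmentation` dependency and the per-row `Vec<&str>`), and output is accumulated in one pre-sized `GenericStringBuilder`, whose `fmt::Write` impl appends bytes without finalizing the row, so the shared space buffer can be sliced in afterwards. A self-contained sketch of the same technique against the `arrow` crates alone; the function name and test data are illustrative:

```rust
use std::fmt::Write;

use arrow_array::builder::GenericStringBuilder;
use arrow_array::{Array, StringArray};

/// Right-pad each value with spaces to `length` characters; never truncate.
fn pad_read_side(input: &StringArray, length: usize) -> StringArray {
    // One allocation for all padding; byte slicing is safe since spaces are 1-byte.
    let spaces = " ".repeat(length);
    // Pre-size: one offset per row, worst-case `length` bytes per value.
    let mut builder =
        GenericStringBuilder::<i32>::with_capacity(input.len(), input.len() * length);
    for value in input.iter() {
        match value {
            Some(s) => {
                // CHAR(n) semantics count Unicode scalar values, not graphemes.
                let char_len = s.chars().count();
                if length <= char_len {
                    builder.append_value(s);
                } else {
                    // write_str fills the value buffer only; append_value then
                    // adds the padding and closes the row's offset.
                    builder.write_str(s).unwrap();
                    builder.append_value(&spaces[char_len..]);
                }
            }
            None => builder.append_null(),
        }
    }
    builder.finish()
}

fn main() {
    let input = StringArray::from(vec![Some("ab"), Some("abc"), None]);
    let padded = pad_read_side(&input, 2);
    assert_eq!(padded.value(0), "ab");
    assert_eq!(padded.value(1), "abc"); // longer than 2: kept, not truncated
    assert!(padded.is_null(2));
}
```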
@@ -2178,7 +2178,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
      }
 
      // With Spark 3.4, CharVarcharCodegenUtils.readSidePadding gets called to pad spaces for
-     // char types. Use rpad to achieve the behavior.
+     // char types.
      // See https://github.com/apache/spark/pull/38151
      case s: StaticInvoke
          if s.staticObject.isInstanceOf[Class[CharVarcharCodegenUtils]] &&
@@ -2194,7 +2194,7 @@
 
        if (argsExpr.forall(_.isDefined)) {
          val builder = ExprOuterClass.ScalarFunc.newBuilder()
-         builder.setFunc("rpad")
+         builder.setFunc("read_side_padding")
          argsExpr.foreach(arg => builder.addArgs(arg.get))
 
          Some(ExprOuterClass.Expr.newBuilder().setScalarFunc(builder).build())
7 changes: 7 additions & 0 deletions spark/src/test/resources/tpcds-micro-benchmarks/char_type.sql
@@ -0,0 +1,7 @@
+SELECT
+  cd_gender
+FROM customer_demographics
+WHERE
+  cd_gender = 'M' AND
+  cd_marital_status = 'S' AND
+  cd_education_status = 'College'
14 changes: 14 additions & 0 deletions spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -1911,6 +1911,20 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("readSidePadding") {
+    // https://stackoverflow.com/a/46290728
+    val table = "test"
+    withTable(table) {
+      sql(s"create table $table(col1 CHAR(2)) using parquet")
+      sql(s"insert into $table values('é')") // unicode 'e\\u{301}'
+      sql(s"insert into $table values('é')") // unicode '\\u{e9}'
+      sql(s"insert into $table values('')")
+      sql(s"insert into $table values('ab')")
+
+      checkSparkAnswerAndOperator(s"SELECT * FROM $table")
+    }
+  }
+
   test("isnan") {
     Seq("true", "false").foreach { dictionary =>
       withSQLConf("parquet.enable.dictionary" -> dictionary) {
@@ -63,6 +63,7 @@ object CometTPCDSMicroBenchmark extends CometTPCQueryBenchmarkBase {
     "agg_sum_integers_no_grouping",
     "case_when_column_or_null",
     "case_when_scalar",
+    "char_type",
     "filter_highly_selective",
     "filter_less_selective",
     "if_column_or_null",
