diff --git a/datagen/src/main/scala/org/apache/spark/sql/tests/datagen/bigDataGen.scala b/datagen/src/main/scala/org/apache/spark/sql/tests/datagen/bigDataGen.scala index 14e0d4e0970..9dc4de9086f 100644 --- a/datagen/src/main/scala/org/apache/spark/sql/tests/datagen/bigDataGen.scala +++ b/datagen/src/main/scala/org/apache/spark/sql/tests/datagen/bigDataGen.scala @@ -29,7 +29,8 @@ import scala.util.Random import org.apache.spark.sql.{Column, DataFrame, Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Expression, XXH64} +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, UnaryExpression, XXH64} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils} import org.apache.spark.sql.functions.{approx_count_distinct, avg, coalesce, col, count, lit, stddev, struct, transform, udf, when} import org.apache.spark.sql.types._ @@ -577,7 +578,8 @@ object DataGen { case class DataGenExpr(child: Expression, override val dataType: DataType, canHaveNulls: Boolean, - f: GeneratorFunction) extends DataGenExprBase { + f: GeneratorFunction) + extends UnaryExpression with ExpectsInputTypes with CodegenFallback { override def nullable: Boolean = canHaveNulls @@ -587,6 +589,9 @@ case class DataGenExpr(child: Expression, val rowLoc = new RowLocation(child.eval(input).asInstanceOf[Long]) f(rowLoc) } + + override def withNewChildInternal(newChild: Expression): Expression = + DataGenExpr(newChild, dataType, canHaveNulls, f) } abstract class CommonDataGen( @@ -2670,7 +2675,9 @@ object ColumnGen { dataType: DataType, nullable: Boolean, gen: GeneratorFunction): Column = { - Column(DataGenExpr(rowNumber.expr, dataType, nullable, gen)) + val rowNumberExpr = DataGenExprShims.columnToExpr(rowNumber) + val expr = DataGenExpr(rowNumberExpr, dataType, nullable, gen) + DataGenExprShims.exprToColumn(expr) } } diff --git a/datagen/src/main/spark320/scala/org/apache/spark/sql/tests/datagen/datagen/DataGenExprBase.scala b/datagen/src/main/spark320/scala/org/apache/spark/sql/tests/datagen/DataGenExprShims.scala similarity index 70% rename from datagen/src/main/spark320/scala/org/apache/spark/sql/tests/datagen/datagen/DataGenExprBase.scala rename to datagen/src/main/spark320/scala/org/apache/spark/sql/tests/datagen/DataGenExprShims.scala index ccbb03c4faa..39e36388a52 100644 --- a/datagen/src/main/spark320/scala/org/apache/spark/sql/tests/datagen/datagen/DataGenExprBase.scala +++ b/datagen/src/main/spark320/scala/org/apache/spark/sql/tests/datagen/DataGenExprShims.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -37,14 +37,13 @@ {"spark": "343"} {"spark": "350"} {"spark": "351"} -{"spark": "400"} spark-rapids-shim-json-lines ***/ package org.apache.spark.sql.tests.datagen -import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, UnaryExpression} -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.expressions.Expression -trait DataGenExprBase extends UnaryExpression with ExpectsInputTypes with CodegenFallback { - override def withNewChildInternal(newChild: Expression): Expression = - legacyWithNewChildren(Seq(newChild)) +object DataGenExprShims { + def columnToExpr(c: Column): Expression = c.expr + def exprToColumn(e: Expression): Column = Column(e) } diff --git a/datagen/src/main/spark400/scala/org/apache/spark/sql/tests/datagen/DataGenExprShims.scala b/datagen/src/main/spark400/scala/org/apache/spark/sql/tests/datagen/DataGenExprShims.scala new file mode 100644 index 00000000000..2884968660d --- /dev/null +++ b/datagen/src/main/spark400/scala/org/apache/spark/sql/tests/datagen/DataGenExprShims.scala @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "400"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.tests.datagen + +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.internal.ExpressionUtils.{column, expression} + +object DataGenExprShims { + def columnToExpr(c: Column): Expression = c + def exprToColumn(e: Expression): Column = e +} diff --git a/tests/src/test/scala/org/apache/spark/sql/rapids/ProjectExprSuite.scala b/tests/src/test/scala/org/apache/spark/sql/rapids/ProjectExprSuite.scala index dd3c832adb4..3c743be5015 100644 --- a/tests/src/test/scala/org/apache/spark/sql/rapids/ProjectExprSuite.scala +++ b/tests/src/test/scala/org/apache/spark/sql/rapids/ProjectExprSuite.scala @@ -26,11 +26,12 @@ import com.nvidia.spark.rapids.jni.RmmSpark import org.mockito.Mockito.{mock, spy, when} import org.apache.spark.SparkConf -import org.apache.spark.sql.{Column, DataFrame, Row, SparkSession} +import org.apache.spark.sql.{DataFrame, Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, Literal, NamedExpression} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.tests.datagen.DataGenExprShims import org.apache.spark.sql.types._ @@ -187,12 +188,15 @@ class ProjectExprSuite extends SparkQueryCompareTestSuite { lit(Array("a", "b", null, "")), lit(Array(Array(1, 2), null, Array(3, 4))), lit(Array(Array(Array(1, 2), Array(2, 3), null), null)), - Column(Literal.create(Array(Row(1, "s1"), Row(2, "s2"), null), - ArrayType(StructType( - Array(StructField("id", IntegerType), StructField("name", StringType)))))), - Column(Literal.create(List(BigDecimal(123L, 2), BigDecimal(-1444L, 2)), + DataGenExprShims.exprToColumn( + Literal.create(Array(Row(1, "s1"), Row(2, "s2"), null), + ArrayType(StructType( + Array(StructField("id", IntegerType), StructField("name", StringType)))))), + DataGenExprShims.exprToColumn( + Literal.create(List(BigDecimal(123L, 2), BigDecimal(-1444L, 2)), ArrayType(DecimalType(10, 2)))), - Column(Literal.create(List(BigDecimal("1234567890123456789012345678")), + DataGenExprShims.exprToColumn( + Literal.create(List(BigDecimal("1234567890123456789012345678")), ArrayType(DecimalType(30, 2)))) ) .selectExpr("*", "array(null)", "array(array(null))", "array()") diff --git a/tests/src/test/scala/org/apache/spark/sql/timezone/TimeZonePerfUtils.scala b/tests/src/test/scala/org/apache/spark/sql/timezone/TimeZonePerfUtils.scala index de3a8e0e755..1e0cba47b78 100644 --- a/tests/src/test/scala/org/apache/spark/sql/timezone/TimeZonePerfUtils.scala +++ b/tests/src/test/scala/org/apache/spark/sql/timezone/TimeZonePerfUtils.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -54,6 +54,7 @@ case class StringGenFunc(strings: Array[String]) extends DefaultGeneratorFunctio object TimeZonePerfUtils { def createColumn(idCol: Column, t: DataType, func: GeneratorFunction): Column = { - Column(DataGenExpr(idCol.expr, t, false, func)) + val expr = DataGenExprShims.columnToExpr(idCol) + DataGenExprShims.exprToColumn(DataGenExpr(expr, t, false, func)) } }