[SPARK-44532][CONNECT][SQL] Move ArrowUtils to sql/api
### What changes were proposed in this pull request?
This PR moves `ArrowUtils` to `sql/api`. One method used for configuring Python's Arrow runner has been moved to `ArrowPythonRunner`.
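
A minimal before/after sketch of the call-site change this implies for configuring the Python Arrow runner (the `conf` value below is a placeholder, both classes are internal Spark APIs, and the actual call sites are in the diff below):

```scala
import org.apache.spark.sql.execution.python.ArrowPythonRunner
import org.apache.spark.sql.internal.SQLConf

val conf: SQLConf = SQLConf.get // placeholder: the active session's SQL conf

// Before: the helper lived on the shared ArrowUtils object.
// val pythonRunnerConf = org.apache.spark.sql.util.ArrowUtils.getPythonRunnerConfMap(conf)

// After: sql/core callers go through the new ArrowPythonRunner companion object,
// so the sql/api copy of ArrowUtils no longer needs SQLConf.
val pythonRunnerConf: Map[String, String] = ArrowPythonRunner.getPythonRunnerConfMap(conf)
```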

### Why are the changes needed?
ArrowUtils is used by Connect's direct Arrow encoding (and by many other parts of sql). We want to remove the Connect Scala client's catalyst dependency, and moving ArrowUtils is a prerequisite for doing so.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Existing tests.

Closes apache#42137 from hvanhovell/SPARK-44532.

Authored-by: Herman van Hovell <herman@databricks.com>
Signed-off-by: Herman van Hovell <herman@databricks.com>
hvanhovell committed Jul 25, 2023
1 parent 4a75a02 commit 307e46c
Showing 19 changed files with 86 additions and 67 deletions.
@@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.errors.{DataTypeErrors, QueryCompilationErrors}
import org.apache.spark.sql.types.Decimal

/**
@@ -341,7 +341,7 @@ object ArrowDeserializers {
}

case (CalendarIntervalEncoder | _: UDTEncoder[_], _) =>
throw QueryExecutionErrors.unsupportedDataTypeError(encoder.dataType)
throw DataTypeErrors.unsupportedDataTypeError(encoder.dataType)

case _ =>
throw new RuntimeException(
@@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.DefinedByConstructorParams
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.errors.DataTypeErrors
import org.apache.spark.sql.types.Decimal
import org.apache.spark.sql.util.ArrowUtils

@@ -439,7 +439,7 @@ object ArrowSerializer {
}

case (CalendarIntervalEncoder | _: UDTEncoder[_], _) =>
throw QueryExecutionErrors.unsupportedDataTypeError(encoder.dataType)
throw DataTypeErrors.unsupportedDataTypeError(encoder.dataType)

case _ =>
throw new RuntimeException(s"Unsupported Encoder($encoder)/Vector($v) combination.")
8 changes: 8 additions & 0 deletions sql/api/pom.xml
@@ -64,6 +64,14 @@
<groupId>org.antlr</groupId>
<artifactId>antlr4-runtime</artifactId>
</dependency>
<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-vector</artifactId>
</dependency>
<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-memory-netty</artifactId>
</dependency>
</dependencies>
<build>
<outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
@@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.errors

import org.apache.arrow.vector.types.pojo.ArrowType

import org.apache.spark.SparkUnsupportedOperationException

trait ArrowErrors {

def unsupportedArrowTypeError(typeName: ArrowType): SparkUnsupportedOperationException = {
new SparkUnsupportedOperationException(
errorClass = "UNSUPPORTED_ARROWTYPE",
messageParameters = Map("typeName" -> typeName.toString))
}

def duplicatedFieldNameInArrowStructError(
fieldNames: Seq[String]): SparkUnsupportedOperationException = {
new SparkUnsupportedOperationException(
errorClass = "DUPLICATED_FIELD_NAME_IN_ARROW_STRUCT",
messageParameters = Map("fieldNames" -> fieldNames.mkString("[", ", ", "]")))
}
}

object ArrowErrors extends ArrowErrors
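
For reference, a small hedged sketch of how the new helpers are raised; the real call sites are in ArrowUtils further down this diff, and the match below is only illustrative:

```scala
import org.apache.arrow.vector.types.pojo.ArrowType

import org.apache.spark.sql.errors.ArrowErrors
import org.apache.spark.sql.types.{BooleanType, DataType}

// Illustrative fragment of an Arrow -> Spark type mapping: anything the mapping
// does not understand surfaces as UNSUPPORTED_ARROWTYPE via the new error helper.
def fromArrowTypeSketch(dt: ArrowType): DataType = dt match {
  case ArrowType.Bool.INSTANCE => BooleanType
  case other => throw ArrowErrors.unsupportedArrowTypeError(other)
}
```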
@@ -295,4 +295,10 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase {
messageParameters = Map("operation" -> operation),
cause = null)
}

def unsupportedDataTypeError(typeName: DataType): SparkUnsupportedOperationException = {
new SparkUnsupportedOperationException(
errorClass = "UNSUPPORTED_DATATYPE",
messageParameters = Map("typeName" -> toSQLType(typeName)))
}
}
@@ -26,8 +26,7 @@ import org.apache.arrow.vector.complex.MapVector
import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, IntervalUnit, TimeUnit}
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}

import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.errors.{ArrowErrors, DataTypeErrors}
import org.apache.spark.sql.types._

private[sql] object ArrowUtils {
@@ -61,7 +60,7 @@ private[sql] object ArrowUtils {
case _: YearMonthIntervalType => new ArrowType.Interval(IntervalUnit.YEAR_MONTH)
case _: DayTimeIntervalType => new ArrowType.Duration(TimeUnit.MICROSECOND)
case _ =>
throw QueryExecutionErrors.unsupportedDataTypeError(dt)
throw DataTypeErrors.unsupportedDataTypeError(dt)
}

def fromArrowType(dt: ArrowType): DataType = dt match {
@@ -86,7 +85,7 @@ private[sql] object ArrowUtils {
case ArrowType.Null.INSTANCE => NullType
case yi: ArrowType.Interval if yi.getUnit == IntervalUnit.YEAR_MONTH => YearMonthIntervalType()
case di: ArrowType.Duration if di.getUnit == TimeUnit.MICROSECOND => DayTimeIntervalType()
case _ => throw QueryExecutionErrors.unsupportedArrowTypeError(dt)
case _ => throw ArrowErrors.unsupportedArrowTypeError(dt)
}

/** Maps field from Spark to Arrow. NOTE: timeZoneId required for TimestampType */
@@ -172,16 +171,6 @@ private[sql] object ArrowUtils {
}.toArray)
}

/** Return Map with conf settings to be used in ArrowPythonRunner */
def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = {
val timeZoneConf = Seq(SQLConf.SESSION_LOCAL_TIMEZONE.key -> conf.sessionLocalTimeZone)
val pandasColsByName = Seq(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME.key ->
conf.pandasGroupedMapAssignColumnsByName.toString)
val arrowSafeTypeCheck = Seq(SQLConf.PANDAS_ARROW_SAFE_TYPE_CONVERSION.key ->
conf.arrowSafeTypeConversion.toString)
Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck: _*)
}

private def deduplicateFieldNames(
dt: DataType, errorOnDuplicatedFieldNames: Boolean): DataType = dt match {
case udt: UserDefinedType[_] => deduplicateFieldNames(udt.sqlType, errorOnDuplicatedFieldNames)
Expand All @@ -190,7 +179,7 @@ private[sql] object ArrowUtils {
st.names
} else {
if (errorOnDuplicatedFieldNames) {
throw QueryExecutionErrors.duplicatedFieldNameInArrowStructError(st.names)
throw ArrowErrors.duplicatedFieldNameInArrowStructError(st.names)
}
val genNawName = st.names.groupBy(identity).map {
case (name, names) if names.length > 1 =>
8 changes: 0 additions & 8 deletions sql/catalyst/pom.xml
@@ -116,14 +116,6 @@
<artifactId>univocity-parsers</artifactId>
<type>jar</type>
</dependency>
<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-vector</artifactId>
</dependency>
<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-memory-netty</artifactId>
</dependency>
<dependency>
<groupId>org.apache.datasketches</groupId>
<artifactId>datasketches-java</artifactId>
@@ -25,7 +25,6 @@ import java.time.temporal.ChronoField
import java.util.concurrent.TimeoutException

import com.fasterxml.jackson.core.{JsonParser, JsonToken}
import org.apache.arrow.vector.types.pojo.ArrowType
import org.apache.hadoop.fs.{FileAlreadyExistsException, FileStatus, Path}
import org.apache.hadoop.fs.permission.FsPermission
import org.codehaus.commons.compiler.{CompileException, InternalCompilerException}
@@ -1142,25 +1141,6 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
messageParameters = Map("cost" -> cost))
}

def unsupportedArrowTypeError(typeName: ArrowType): SparkUnsupportedOperationException = {
new SparkUnsupportedOperationException(
errorClass = "UNSUPPORTED_ARROWTYPE",
messageParameters = Map("typeName" -> typeName.toString))
}

def unsupportedDataTypeError(typeName: DataType): SparkUnsupportedOperationException = {
new SparkUnsupportedOperationException(
errorClass = "UNSUPPORTED_DATATYPE",
messageParameters = Map("typeName" -> toSQLType(typeName)))
}

def duplicatedFieldNameInArrowStructError(
fieldNames: Seq[String]): SparkUnsupportedOperationException = {
new SparkUnsupportedOperationException(
errorClass = "DUPLICATED_FIELD_NAME_IN_ARROW_STRUCT",
messageParameters = Map("fieldNames" -> fieldNames.mkString("[", ", ", "]")))
}

def notSupportTypeError(dataType: DataType): Throwable = {
new SparkException(
errorClass = "_LEGACY_ERROR_TEMP_2100",
@@ -24,7 +24,7 @@ import org.apache.arrow.vector.complex._

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.errors.DataTypeErrors
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.ArrowUtils

@@ -83,7 +83,7 @@ object ArrowWriter {
case (_: YearMonthIntervalType, vector: IntervalYearVector) => new IntervalYearWriter(vector)
case (_: DayTimeIntervalType, vector: DurationVector) => new DurationWriter(vector)
case (dt, _) =>
throw QueryExecutionErrors.unsupportedDataTypeError(dt)
throw DataTypeErrors.unsupportedDataTypeError(dt)
}
}
}
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.errors.DataTypeErrors
import org.apache.spark.sql.execution.command.DataWritingCommandExec
import org.apache.spark.sql.execution.datasources.V1WriteCommand
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
@@ -277,7 +277,7 @@ private object RowToColumnConverter {
case dt: DecimalType => new DecimalConverter(dt)
case mt: MapType => MapConverter(getConverterForType(mt.keyType, nullable = false),
getConverterForType(mt.valueType, mt.valueContainsNull))
case unknown => throw QueryExecutionErrors.unsupportedDataTypeError(unknown)
case unknown => throw DataTypeErrors.unsupportedDataTypeError(unknown)
}

if (nullable) {
@@ -31,7 +31,6 @@ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistrib
import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.aggregate.UpdatingSessionsIterator
import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.util.Utils

/**
@@ -102,7 +101,7 @@ case class AggregateInPandasExec(

val sessionLocalTimeZone = conf.sessionLocalTimeZone
val largeVarTypes = conf.arrowUseLargeVarTypes
val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)

val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip

@@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.ArrowUtils

/**
* Grouped a iterator into batches.
@@ -75,7 +74,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute]
evalType,
conf.sessionLocalTimeZone,
conf.arrowUseLargeVarTypes,
ArrowUtils.getPythonRunnerConfMap(conf),
ArrowPythonRunner.getPythonRunnerConfMap(conf),
pythonMetrics,
jobArtifactUUID)
}
@@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}

/**
@@ -49,7 +48,7 @@ case class ArrowEvalPythonUDTFExec(
private val batchSize = conf.arrowMaxRecordsPerBatch
private val sessionLocalTimeZone = conf.sessionLocalTimeZone
private val largeVarTypes = conf.arrowUseLargeVarTypes
private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)

override protected def evaluate(
@@ -56,3 +56,15 @@ class ArrowPythonRunner(
"Pandas execution requires more than 4 bytes. Please set higher buffer. " +
s"Please change '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'.")
}

object ArrowPythonRunner {
/** Return Map with conf settings to be used in ArrowPythonRunner */
def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = {
val timeZoneConf = Seq(SQLConf.SESSION_LOCAL_TIMEZONE.key -> conf.sessionLocalTimeZone)
val pandasColsByName = Seq(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME.key ->
conf.pandasGroupedMapAssignColumnsByName.toString)
val arrowSafeTypeCheck = Seq(SQLConf.PANDAS_ARROW_SAFE_TYPE_CONVERSION.key ->
conf.arrowSafeTypeConversion.toString)
Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck: _*)
}
}
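
A short sketch of how a plan node consumes the relocated helper; this mirrors the call sites in the exec nodes further down, the enclosing class is omitted, and the method name is hypothetical:

```scala
import org.apache.spark.sql.execution.python.ArrowPythonRunner
import org.apache.spark.sql.internal.SQLConf

// Build the per-execution settings that travel with the Arrow-based Python runner.
def runnerSettings(conf: SQLConf): (String, Boolean, Map[String, String]) = {
  val sessionLocalTimeZone = conf.sessionLocalTimeZone // session time zone
  val largeVarTypes = conf.arrowUseLargeVarTypes       // large var-width Arrow types
  val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
  (sessionLocalTimeZone, largeVarTypes, pythonRunnerConf)
}
```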
@@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistrib
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.execution.{BinaryExecNode, CoGroupedIterator, SparkPlan}
import org.apache.spark.sql.execution.python.PandasGroupUtils._
import org.apache.spark.sql.util.ArrowUtils


/**
@@ -58,7 +57,7 @@ case class FlatMapCoGroupsInPandasExec(
extends SparkPlan with BinaryExecNode with PythonSQLMetrics {

private val sessionLocalTimeZone = conf.sessionLocalTimeZone
private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
private val pandasFunction = func.asInstanceOf[PythonUDF].func
private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))

@@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistrib
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.python.PandasGroupUtils._
import org.apache.spark.sql.util.ArrowUtils


/**
@@ -55,7 +54,7 @@ case class FlatMapGroupsInPandasExec(

private val sessionLocalTimeZone = conf.sessionLocalTimeZone
private val largeVarTypes = conf.arrowUseLargeVarTypes
private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
private val pandasFunction = func.asInstanceOf[PythonUDF].func
private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
@@ -33,7 +33,6 @@ import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExec
import org.apache.spark.sql.execution.streaming.state.StateStore
import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.util.CompletionIterator

/**
@@ -81,7 +80,7 @@ case class FlatMapGroupsInPandasWithStateExec(
override def output: Seq[Attribute] = outAttributes

private val sessionLocalTimeZone = conf.sessionLocalTimeZone
private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)

private val pythonFunction = functionExpr.asInstanceOf[PythonUDF].func
private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction)))
@@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.UnaryExecNode
import org.apache.spark.sql.util.ArrowUtils

/**
* A relation produced by applying a function that takes an iterator of batches
@@ -46,7 +45,7 @@ trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics {
override def outputPartitioning: Partitioning = child.outputPartitioning

override protected def doExecute(): RDD[InternalRow] = {
val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
val pythonFunction = func.asInstanceOf[PythonUDF].func
val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction)))
val evaluatorFactory = new MapInBatchEvaluatorFactory(
@@ -33,7 +33,6 @@ import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.window.{SlidingWindowFunctionFrame, UnboundedFollowingWindowFunctionFrame, UnboundedPrecedingWindowFunctionFrame, UnboundedWindowFunctionFrame, WindowEvaluatorFactoryBase, WindowFunctionFrame}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, IntegerType, StructField, StructType}
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.util.Utils

class WindowInPandasEvaluatorFactory(
@@ -162,7 +161,8 @@ class WindowInPandasEvaluatorFactory(

private val udfWindowBoundTypes = pyFuncs.indices.map(i =>
frameWindowBoundTypes(expressionIndexToFrameIndex(i)))
private val pythonRunnerConf: Map[String, String] = (ArrowUtils.getPythonRunnerConfMap(conf)
private val pythonRunnerConf: Map[String, String] =
(ArrowPythonRunner.getPythonRunnerConfMap(conf)
+ (windowBoundTypeConf -> udfWindowBoundTypes.map(_.value).mkString(",")))

// Filter child output attributes down to only those that are UDF inputs.