Skip to content

Commit

Permalink
[SEDONA-607] Fix error message enhancements for geometry functions (#…
Browse files Browse the repository at this point in the history
…1555)

* Fix error message enhancements for geometry functions

* Fix compilation for scala 2.13
  • Loading branch information
Kontinuation committed Aug 21, 2024
1 parent 0e3d4c9 commit 1737174
Show file tree
Hide file tree
Showing 7 changed files with 155 additions and 105 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -160,21 +160,20 @@ case class ST_GeomFromWKB(inputExpressions: Seq[Expression])
override def nullable: Boolean = true

override def eval(inputRow: InternalRow): Any = {
val arg = inputExpressions.head.eval(inputRow)
try {
(inputExpressions.head.eval(inputRow)) match {
case (geomString: UTF8String) => {
arg match {
case geomString: UTF8String =>
// Parse UTF-8 encoded wkb string
Constructors.geomFromText(geomString.toString, FileDataSplitter.WKB).toGenericArrayData
}
case (wkb: Array[Byte]) => {
case wkb: Array[Byte] =>
// convert raw wkb byte array to geometry
Constructors.geomFromWKB(wkb).toGenericArrayData
}
case null => null
}
} catch {
case e: Exception =>
InferredExpression.throwExpressionInferenceException(inputRow, inputExpressions, e)
InferredExpression.throwExpressionInferenceException(getClass.getSimpleName, Seq(arg), e)
}
}

Expand All @@ -201,21 +200,20 @@ case class ST_GeomFromEWKB(inputExpressions: Seq[Expression])
override def nullable: Boolean = true

override def eval(inputRow: InternalRow): Any = {
val arg = inputExpressions.head.eval(inputRow)
try {
(inputExpressions.head.eval(inputRow)) match {
case (geomString: UTF8String) => {
arg match {
case geomString: UTF8String =>
// Parse UTF-8 encoded wkb string
Constructors.geomFromText(geomString.toString, FileDataSplitter.WKB).toGenericArrayData
}
case (wkb: Array[Byte]) => {
case wkb: Array[Byte] =>
// convert raw wkb byte array to geometry
Constructors.geomFromWKB(wkb).toGenericArrayData
}
case null => null
}
} catch {
case e: Exception =>
InferredExpression.throwExpressionInferenceException(inputRow, inputExpressions, e)
InferredExpression.throwExpressionInferenceException(getClass.getSimpleName, Seq(arg), e)
}
}

Expand Down Expand Up @@ -267,7 +265,10 @@ case class ST_LineFromWKB(inputExpressions: Seq[Expression])
}
} catch {
case e: Exception =>
InferredExpression.throwExpressionInferenceException(inputRow, inputExpressions, e)
InferredExpression.throwExpressionInferenceException(
getClass.getSimpleName,
Seq(wkb, srid),
e)
}
}

Expand Down Expand Up @@ -321,7 +322,10 @@ case class ST_LinestringFromWKB(inputExpressions: Seq[Expression])
}
} catch {
case e: Exception =>
InferredExpression.throwExpressionInferenceException(inputRow, inputExpressions, e)
InferredExpression.throwExpressionInferenceException(
getClass.getSimpleName,
Seq(wkb, srid),
e)
}
}

Expand Down Expand Up @@ -375,7 +379,10 @@ case class ST_PointFromWKB(inputExpressions: Seq[Expression])
}
} catch {
case e: Exception =>
InferredExpression.throwExpressionInferenceException(inputRow, inputExpressions, e)
InferredExpression.throwExpressionInferenceException(
getClass.getSimpleName,
Seq(wkb, srid),
e)
}
}

Expand Down Expand Up @@ -413,7 +420,6 @@ case class ST_GeomFromGeoJSON(inputExpressions: Seq[Expression])
override def eval(inputRow: InternalRow): Any = {
val geomString = inputExpressions.head.eval(inputRow).asInstanceOf[UTF8String].toString
try {

val geometry = Constructors.geomFromText(geomString, FileDataSplitter.GEOJSON)
// If the user specify a bunch of attributes to go with each geometry, we need to store all of them in this geometry
if (inputExpressions.length > 1) {
Expand All @@ -422,7 +428,10 @@ case class ST_GeomFromGeoJSON(inputExpressions: Seq[Expression])
GeometrySerializer.serialize(geometry)
} catch {
case e: Exception =>
InferredExpression.throwExpressionInferenceException(inputRow, inputExpressions, e)
InferredExpression.throwExpressionInferenceException(
getClass.getSimpleName,
Seq(geomString),
e)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,10 @@ case class ST_IsValidDetail(children: Seq[Expression])
Seq(validDetail.valid, UTF8String.fromString(validDetail.reason), serLocation))
} catch {
case e: Exception =>
InferredExpression.throwExpressionInferenceException(input, children, e)
InferredExpression.throwExpressionInferenceException(
getClass.getSimpleName,
Seq(geometry),
e)
}
}

Expand Down Expand Up @@ -627,20 +630,19 @@ case class ST_MinimumBoundingRadius(inputExpressions: Seq[Expression])

override def eval(input: InternalRow): Any = {
val expr = inputExpressions(0)
val geometry = expr.toGeometry(input)

try {
val geometry = expr match {
case s: SerdeAware => s.evalWithoutSerialization(input)
case _ => expr.toGeometry(input)
}

geometry match {
case geometry: Geometry => getMinimumBoundingRadius(geometry)
case _ => null
}
} catch {
case e: Exception =>
InferredExpression.throwExpressionInferenceException(input, inputExpressions, e)
InferredExpression.throwExpressionInferenceException(
getClass.getSimpleName,
Seq(geometry),
e)
}
}

Expand Down Expand Up @@ -932,22 +934,24 @@ case class ST_SubDivideExplode(children: Seq[Expression]) extends Generator with
children.validateLength(2)

override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
val geometryRaw = children.head
val maxVerticesRaw = children(1)
val geometry = children.head.toGeometry(input)
val maxVertices = children(1).toInt(input)
try {
geometryRaw.toGeometry(input) match {
geometry match {
case geom: Geometry =>
ArrayData.toArrayData(
Functions.subDivide(geom, maxVerticesRaw.toInt(input)).map(_.toGenericArrayData))
ArrayData.toArrayData(Functions.subDivide(geom, maxVertices).map(_.toGenericArrayData))
Functions
.subDivide(geom, maxVerticesRaw.toInt(input))
.subDivide(geom, maxVertices)
.map(_.toGenericArrayData)
.map(InternalRow(_))
case _ => new Array[InternalRow](0)
}
} catch {
case e: Exception =>
InferredExpression.throwExpressionInferenceException(input, children, e)
InferredExpression.throwExpressionInferenceException(
getClass.getSimpleName,
Seq(geometry, maxVertices),
e)
}
}

Expand Down Expand Up @@ -1008,8 +1012,8 @@ case class ST_MaximumInscribedCircle(children: Seq[Expression])
with CodegenFallback {

override def eval(input: InternalRow): Any = {
val geometry = children.head.toGeometry(input)
try {
val geometry = children.head.toGeometry(input)
var inscribedCircle: InscribedCircle = null
inscribedCircle = Functions.maximumInscribedCircle(geometry)

Expand All @@ -1018,7 +1022,10 @@ case class ST_MaximumInscribedCircle(children: Seq[Expression])
InternalRow.fromSeq(Seq(serCenter, serNearest, inscribedCircle.radius))
} catch {
case e: Exception =>
InferredExpression.throwExpressionInferenceException(input, children, e)
InferredExpression.throwExpressionInferenceException(
getClass.getSimpleName,
Seq(geometry),
e)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,19 @@
*/
package org.apache.spark.sql.sedona_sql.expressions

import org.apache.commons.lang3.StringUtils
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes, Literal}
import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
import org.apache.spark.sql.types.{AbstractDataType, BinaryType, BooleanType, DataType, DataTypes, DoubleType, IntegerType, LongType, StringType, StructField, StructType}
import org.apache.spark.sql.types.{AbstractDataType, BinaryType, BooleanType, DataType, DataTypes, DoubleType, IntegerType, LongType, StringType}
import org.apache.spark.unsafe.types.UTF8String
import org.locationtech.jts.geom.Geometry
import org.apache.spark.sql.sedona_sql.expressions.implicits._

import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable`
import scala.collection.mutable.ArrayBuffer
import scala.reflect.runtime.universe.TypeTag
import scala.reflect.runtime.universe.Type
import scala.reflect.runtime.universe.typeOf
Expand Down Expand Up @@ -77,27 +79,40 @@ abstract class InferredExpression(fSeq: InferrableFunction*)
override def inputTypes: Seq[AbstractDataType] = f.sparkInputTypes
override def dataType: DataType = f.sparkReturnType

private lazy val argExtractors: Array[InternalRow => Any] = f.buildExtractors(inputExpressions)
private lazy val argExtractors: Array[InternalRow => Any] = buildExtractors(inputExpressions)
private lazy val evaluator: InternalRow => Any = f.evaluatorBuilder(argExtractors)

private def findAllLiterals(expression: Expression): Seq[Literal] = {
expression match {
case lit: Literal => Seq(lit)
case _ => expression.children.flatMap(findAllLiterals)
}
}
// Remember input args to generate error messages when exceptions occur. The input arguments are
// helpful for troubleshooting the cause of errors.
private val inputArgs: ArrayBuffer[AnyRef] = ArrayBuffer.empty[AnyRef]

private def findAllLiteralsInExpressions(expressions: Seq[Expression]): Seq[String] = {
expressions.flatMap(findAllLiterals).map(_.value.toString)
private def buildExtractors(expressions: Seq[Expression]): Array[InternalRow => Any] = {
f.argExtractorBuilders
.zipAll(expressions, null, null)
.flatMap {
case (null, _) => None
case (builder, expr) =>
val extractor = builder(expr)
Some((input: InternalRow) => {
val arg = extractor(input)
inputArgs += arg.asInstanceOf[AnyRef]
arg
})
}
.toArray
}

override def eval(input: InternalRow): Any = {

try {
f.serializer(evaluator(input))
} catch {
case e: Exception =>
InferredExpression.throwExpressionInferenceException(input, inputExpressions, e)
InferredExpression.throwExpressionInferenceException(
getClass.getSimpleName,
inputArgs.toSeq,
e)
} finally {
inputArgs.clear()
}
}

Expand All @@ -106,32 +121,32 @@ abstract class InferredExpression(fSeq: InferrableFunction*)
evaluator(input)
} catch {
case e: Exception =>
InferredExpression.throwExpressionInferenceException(input, inputExpressions, e)
InferredExpression.throwExpressionInferenceException(
getClass.getSimpleName,
inputArgs.toSeq,
e)
} finally {
inputArgs.clear()
}
}
}

object InferredExpression {
def throwExpressionInferenceException(
input: InternalRow,
inputExpressions: Seq[Expression],
name: String,
inputArgs: Seq[Any],
e: Exception): Nothing = {
val literalsAsStrings = if (input == null) {
// In case no input row is provided, we can't extract literals from the input expressions.
inputExpressions.flatMap(findAllLiterals).map(_.value.toString)
if (e.isInstanceOf[InferredExpressionException]) {
throw e
} else {
Seq.empty[String]
}
val literalsOrInputString = literalsAsStrings.mkString(", ")
throw new InferredExpressionException(
s"Exception occurred while evaluating expression - source: [$literalsOrInputString]",
e)
}

def findAllLiterals(expression: Expression): Seq[Literal] = {
expression match {
case lit: Literal => Seq(lit)
case _ => expression.children.flatMap(findAllLiterals)
val inputsAsStrings = inputArgs.map { arg =>
val argStr = if (arg != null) arg.toString else "null"
StringUtils.abbreviate(argStr, 5000)
}
val inputsString = inputsAsStrings.mkString(", ")
throw new InferredExpressionException(
s"Exception occurred while evaluating expression $name - inputs: [$inputsString]",
e)
}
}
}
Expand Down Expand Up @@ -301,17 +316,7 @@ case class InferrableFunction(
sparkReturnType: DataType,
serializer: Any => Any,
argExtractorBuilders: Seq[Expression => InternalRow => Any],
evaluatorBuilder: Array[InternalRow => Any] => InternalRow => Any) {
def buildExtractors(expressions: Seq[Expression]): Array[InternalRow => Any] = {
argExtractorBuilders
.zipAll(expressions, null, null)
.flatMap {
case (null, _) => None
case (builder, expr) => Some(builder(expr))
}
.toArray
}
}
evaluatorBuilder: Array[InternalRow => Any] => InternalRow => Any)

object InferrableFunction {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,16 @@ abstract class ST_Predicate
if (rightArray == null) {
null
} else {
val leftGeometry = GeometrySerializer.deserialize(leftArray)
val rightGeometry = GeometrySerializer.deserialize(rightArray)
try {
val leftGeometry = GeometrySerializer.deserialize(leftArray)
val rightGeometry = GeometrySerializer.deserialize(rightArray)
evalGeom(leftGeometry, rightGeometry)
} catch {
case e: Exception =>
InferredExpression.throwExpressionInferenceException(inputRow, inputExpressions, e)
InferredExpression.throwExpressionInferenceException(
getClass.getSimpleName,
Seq(leftGeometry, rightGeometry),
e)
}
}
}
Expand Down
Loading

0 comments on commit 1737174

Please sign in to comment.