Skip to content

Commit

Permalink
[SPARK-49909][SQL] Fix the pretty name of some expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
panbingkun committed Oct 8, 2024
1 parent ef142c4 commit 6fc927f
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import scala.collection.mutable
import scala.collection.mutable.Growable

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.trees.UnaryLike
Expand Down Expand Up @@ -118,7 +118,8 @@ case class CollectList(

override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty

override def prettyName: String = "collect_list"
override def prettyName: String =
getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("collect_list")

override def eval(buffer: mutable.ArrayBuffer[Any]): Any = {
new GenericArrayData(buffer.toArray)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,8 @@ case class CurrentDate(timeZoneId: Option[String] = None)
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
copy(timeZoneId = Option(timeZoneId))

override def prettyName: String = "current_date"
override def prettyName: String =
getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("current_date")
}

// scalastyle:off line.size.limit
Expand Down Expand Up @@ -329,7 +330,7 @@ case class DateAdd(startDate: Expression, days: Expression)
})
}

override def prettyName: String = "date_add"
override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("date_add")

override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): DateAdd = copy(startDate = newLeft, days = newRight)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,8 @@ object AssertTrue {
case class CurrentDatabase() extends LeafExpression with Unevaluable {
override def dataType: DataType = SQLConf.get.defaultStringType
override def nullable: Boolean = false
override def prettyName: String = "current_schema"
override def prettyName: String =
getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("current_database")
final override val nodePatterns: Seq[TreePattern] = Seq(CURRENT_LIKE)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import java.util.Locale

import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResult}
import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.Cast._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
Expand Down Expand Up @@ -307,7 +307,10 @@ case class ToCharacter(left: Expression, right: Expression)
inputTypeCheck
}
}
override def prettyName: String = "to_char"

override def prettyName: String =
getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("to_char")

override def nullSafeEval(decimal: Any, format: Any): Any = {
val input = decimal.asInstanceOf[Decimal]
numberFormatter.format(input)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedSeed}
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult, UnresolvedSeed}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.ExpectsInputTypes.{ordinalNumber, toSQLExpr, toSQLType}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
Expand Down Expand Up @@ -128,8 +128,12 @@ case class Rand(child: Expression, hideSeed: Boolean = false) extends Nondetermi
}

override def flatArguments: Iterator[Any] = Iterator(child)

override def prettyName: String =
getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("rand")

override def sql: String = {
s"rand(${if (hideSeed) "" else child.sql})"
s"$prettyName(${if (hideSeed) "" else child.sql})"
}

override protected def withNewChildInternal(newChild: Expression): Rand = copy(child = newChild)
Expand Down

0 comments on commit 6fc927f

Please sign in to comment.