Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-49962][SQL] Simplify AbstractStringTypes class hierarchy #48459

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1091,15 +1091,13 @@ public static int collationNameToId(String collationName) throws SparkException
return Collation.CollationSpec.collationNameToId(collationName);
}

/**
* Returns whether the ICU collation is not Case Sensitive Accent Insensitive
* for the given collation id.
* This method is used in expressions which do not support CS_AI collations.
*/
public static boolean isCaseSensitiveAndAccentInsensitive(int collationId) {
public static boolean isCaseInsensitive(int collationId) {
return Collation.CollationSpecICU.fromCollationId(collationId).caseSensitivity ==
Collation.CollationSpecICU.CaseSensitivity.CS &&
Collation.CollationSpecICU.fromCollationId(collationId).accentSensitivity ==
Collation.CollationSpecICU.CaseSensitivity.CI;
}

public static boolean isAccentInsensitive(int collationId) {
return Collation.CollationSpecICU.fromCollationId(collationId).accentSensitivity ==
Collation.CollationSpecICU.AccentSensitivity.AI;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,25 +21,35 @@ import org.apache.spark.sql.internal.SqlApiConf
import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType}

/**
* AbstractStringType is an abstract class for StringType with collation support. As every type of
* collation can support trim specifier this class is parametrized with it.
* AbstractStringType is an abstract class for StringType with collation support.
*/
abstract class AbstractStringType(private[sql] val supportsTrimCollation: Boolean = false)
abstract class AbstractStringType(supportsTrimCollation: Boolean = false)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why did you remove private[sql] val? let's keep it private unless there's a need

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does it need to private? Looks like other classes inheriting from AbstractDataType like AbstractMapType don't have their members private either.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we keep it private to SQL in all of them? no need to expose stuff unless there's a need to expose stuff

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sql/internal is already a private package so we won;t need to mark it as private[sql]. We should add a package.scala and document so, see https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

extends AbstractDataType {
override private[sql] def defaultConcreteType: DataType = SqlApiConf.get.defaultStringType
override private[sql] def simpleString: String = "string"
private[sql] def canUseTrimCollation(other: DataType): Boolean =
supportsTrimCollation || !other.asInstanceOf[StringType].usesTrimCollation

override private[sql] def acceptsType(other: DataType): Boolean = {
other match {
case st: StringType =>
canUseTrimCollation(st) && acceptsStringType(st)
case _ =>
false
}
}
stefankandic marked this conversation as resolved.
Show resolved Hide resolved
def acceptsStringType(other: StringType): Boolean

private[sql] def canUseTrimCollation(other: StringType): Boolean =
supportsTrimCollation || !other.usesTrimCollation
stefankandic marked this conversation as resolved.
Show resolved Hide resolved
}

/**
* Use StringTypeBinary for expressions supporting only binary collation.
* Used for expressions supporting only binary collation.
*/
case class StringTypeBinary(override val supportsTrimCollation: Boolean = false)
case class StringTypeBinary(supportsTrimCollation: Boolean)
extends AbstractStringType(supportsTrimCollation) {
override private[sql] def acceptsType(other: DataType): Boolean =
other.isInstanceOf[StringType] && other.asInstanceOf[StringType].supportsBinaryEquality &&
canUseTrimCollation(other)

override def acceptsStringType(other: StringType): Boolean =
other.supportsBinaryEquality
}

object StringTypeBinary extends StringTypeBinary(false) {
Expand All @@ -49,13 +59,13 @@ object StringTypeBinary extends StringTypeBinary(false) {
}

/**
* Use StringTypeBinaryLcase for expressions supporting only binary and lowercase collation.
* Used for expressions supporting only binary and lowercase collation.
*/
case class StringTypeBinaryLcase(override val supportsTrimCollation: Boolean = false)
case class StringTypeBinaryLcase(supportsTrimCollation: Boolean)
extends AbstractStringType(supportsTrimCollation) {
override private[sql] def acceptsType(other: DataType): Boolean =
other.isInstanceOf[StringType] && (other.asInstanceOf[StringType].supportsBinaryEquality ||
other.asInstanceOf[StringType].isUTF8LcaseCollation) && canUseTrimCollation(other)

override def acceptsStringType(other: StringType): Boolean =
other.supportsBinaryEquality || other.isUTF8LcaseCollation
}

object StringTypeBinaryLcase extends StringTypeBinaryLcase(false) {
Expand All @@ -65,31 +75,42 @@ object StringTypeBinaryLcase extends StringTypeBinaryLcase(false) {
}

/**
* Use StringTypeWithCaseAccentSensitivity for expressions supporting all collation types (binary
* and ICU) but limited to using case and accent sensitivity specifiers.
* Used for expressions supporting collation types with optional
* case, accent, and trim sensitivity specifiers.
*
* Case and accent sensitivity specifiers are supported by default.
*/
case class StringTypeWithCaseAccentSensitivity(
override val supportsTrimCollation: Boolean = false)
case class StringTypeWithCollation(
supportsTrimCollation: Boolean,
supportsCaseSpecifier: Boolean,
supportsAccentSpecifier: Boolean)
extends AbstractStringType(supportsTrimCollation) {
override private[sql] def acceptsType(other: DataType): Boolean =
other.isInstanceOf[StringType] && canUseTrimCollation(other)

override def acceptsStringType(other: StringType): Boolean = {
(supportsCaseSpecifier || !other.isCaseInsensitive) &&
(supportsAccentSpecifier || !other.isAccentInsensitive)
}
}

object StringTypeWithCaseAccentSensitivity extends StringTypeWithCaseAccentSensitivity(false) {
def apply(supportsTrimCollation: Boolean): StringTypeWithCaseAccentSensitivity = {
new StringTypeWithCaseAccentSensitivity(supportsTrimCollation)
object StringTypeWithCollation extends StringTypeWithCollation(false, true, true) {
def apply(
supportsTrimCollation: Boolean = false,
supportsCaseSpecifier: Boolean = true,
supportsAccentSpecifier: Boolean = true): StringTypeWithCollation = {
new StringTypeWithCollation(
supportsTrimCollation, supportsCaseSpecifier, supportsAccentSpecifier)
}
}

/**
* Use StringTypeNonCSAICollation for expressions supporting all possible collation types except
* CS_AI collation types.
* Used for expressions supporting all possible collation types except
* those that are case-sensitive but accent insensitive (CS_AI).
*/
case class StringTypeNonCSAICollation(override val supportsTrimCollation: Boolean = false)
case class StringTypeNonCSAICollation(supportsTrimCollation: Boolean)
extends AbstractStringType(supportsTrimCollation) {
override private[sql] def acceptsType(other: DataType): Boolean =
other.isInstanceOf[StringType] && other.asInstanceOf[StringType].isNonCSAI &&
canUseTrimCollation(other)

override def acceptsStringType(other: StringType): Boolean =
other.isCaseInsensitive || !other.isAccentInsensitive
}

object StringTypeNonCSAICollation extends StringTypeNonCSAICollation(false) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,11 @@ class StringType private (val collationId: Int) extends AtomicType with Serializ
private[sql] def supportsLowercaseEquality: Boolean =
CollationFactory.fetchCollation(collationId).supportsLowercaseEquality

private[sql] def isNonCSAI: Boolean =
!CollationFactory.isCaseSensitiveAndAccentInsensitive(collationId)
private[sql] def isCaseInsensitive: Boolean =
CollationFactory.isCaseInsensitive(collationId)

private[sql] def isAccentInsensitive: Boolean =
CollationFactory.isAccentInsensitive(collationId)

private[sql] def usesTrimCollation: Boolean =
CollationFactory.usesTrimCollation(collationId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.{AbstractArrayType, AbstractMapType, AbstractStringType,
StringTypeWithCaseAccentSensitivity}
StringTypeWithCollation}
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.UpCastRule.numericPrecedence

Expand Down Expand Up @@ -439,7 +439,7 @@ abstract class TypeCoercionBase {
}

case aj @ ArrayJoin(arr, d, nr)
if !AbstractArrayType(StringTypeWithCaseAccentSensitivity).acceptsType(arr.dataType) &&
if !AbstractArrayType(StringTypeWithCollation).acceptsType(arr.dataType) &&
ArrayType.acceptsType(arr.dataType) =>
val containsNull = arr.dataType.asInstanceOf[ArrayType].containsNull
implicitCast(arr, ArrayType(StringType, containsNull)) match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch,
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity
import org.apache.spark.sql.internal.types.StringTypeWithCollation
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.ArrayImplicits._
Expand Down Expand Up @@ -84,7 +84,7 @@ case class CallMethodViaReflection(
errorSubClass = "NON_FOLDABLE_INPUT",
messageParameters = Map(
"inputName" -> toSQLId("class"),
"inputType" -> toSQLType(StringTypeWithCaseAccentSensitivity),
"inputType" -> toSQLType(StringTypeWithCollation),
"inputExpr" -> toSQLExpr(children.head)
)
)
Expand All @@ -97,7 +97,7 @@ case class CallMethodViaReflection(
errorSubClass = "NON_FOLDABLE_INPUT",
messageParameters = Map(
"inputName" -> toSQLId("method"),
"inputType" -> toSQLType(StringTypeWithCaseAccentSensitivity),
"inputType" -> toSQLType(StringTypeWithCollation),
"inputExpr" -> toSQLExpr(children(1))
)
)
Expand All @@ -115,7 +115,7 @@ case class CallMethodViaReflection(
"requiredType" -> toSQLType(
TypeCollection(BooleanType, ByteType, ShortType,
IntegerType, LongType, FloatType, DoubleType,
StringTypeWithCaseAccentSensitivity)),
StringTypeWithCollation)),
"inputSql" -> toSQLExpr(e),
"inputType" -> toSQLType(e.dataType))
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity
import org.apache.spark.sql.internal.types.StringTypeWithCollation
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

case class CollationKey(expr: Expression) extends UnaryExpression with ExpectsInputTypes {
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true))
Seq(StringTypeWithCollation(supportsTrimCollation = true))
override def dataType: DataType = BinaryType

final lazy val collationId: Int = expr.dataType match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CharVarcharUtils}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase, QueryExecutionErrors}
import org.apache.spark.sql.internal.types.{AbstractMapType, StringTypeWithCaseAccentSensitivity}
import org.apache.spark.sql.internal.types.{AbstractMapType, StringTypeWithCollation}
import org.apache.spark.sql.types.{DataType, MapType, StringType, StructType, VariantType}
import org.apache.spark.unsafe.types.UTF8String

Expand Down Expand Up @@ -61,7 +61,7 @@ object ExprUtils extends QueryErrorsBase {

def convertToMapData(exp: Expression): Map[String, String] = exp match {
case m: CreateMap
if AbstractMapType(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity)
if AbstractMapType(StringTypeWithCollation, StringTypeWithCollation)
.acceptsType(m.dataType) =>
val arrayMap = m.eval().asInstanceOf[ArrayBasedMapData]
ArrayBasedMapData.toScalaMap(arrayMap).map { case (key, value) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression,
import org.apache.spark.sql.catalyst.trees.BinaryLike
import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity
import org.apache.spark.sql.internal.types.StringTypeWithCollation
import org.apache.spark.sql.types.{AbstractDataType, BinaryType, BooleanType, DataType, IntegerType, LongType, StringType, TypeCollection}
import org.apache.spark.unsafe.types.UTF8String

Expand Down Expand Up @@ -109,7 +109,7 @@ case class HllSketchAgg(
TypeCollection(
IntegerType,
LongType,
StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true),
StringTypeWithCollation(supportsTrimCollation = true),
BinaryType),
IntegerType)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity
import org.apache.spark.sql.internal.types.StringTypeWithCollation
import org.apache.spark.sql.types._

// scalastyle:off line.contains.tab
Expand Down Expand Up @@ -78,7 +78,7 @@ case class Collate(child: Expression, collationName: String)
private val collationId = CollationFactory.collationNameToId(collationName)
override def dataType: DataType = StringType(collationId)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true))
Seq(StringTypeWithCollation(supportsTrimCollation = true))

override protected def withNewChildInternal(
newChild: Expression): Expression = copy(newChild)
Expand Down Expand Up @@ -117,5 +117,5 @@ case class Collation(child: Expression)
Literal.create(collationName, SQLConf.get.defaultStringType)
}
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true))
Seq(StringTypeWithCollation(supportsTrimCollation = true))
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
import org.apache.spark.sql.errors.{QueryErrorsBase, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeWithCaseAccentSensitivity}
import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeWithCollation}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SQLOpenHashSet
import org.apache.spark.unsafe.UTF8StringBuilder
Expand Down Expand Up @@ -1349,7 +1349,7 @@ case class Reverse(child: Expression)

// Input types are utilized by type coercion in ImplicitTypeCasts.
override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(StringTypeWithCaseAccentSensitivity, ArrayType))
Seq(TypeCollection(StringTypeWithCollation, ArrayType))

override def dataType: DataType = child.dataType

Expand Down Expand Up @@ -2135,12 +2135,12 @@ case class ArrayJoin(
this(array, delimiter, Some(nullReplacement))

override def inputTypes: Seq[AbstractDataType] = if (nullReplacement.isDefined) {
Seq(AbstractArrayType(StringTypeWithCaseAccentSensitivity),
StringTypeWithCaseAccentSensitivity,
StringTypeWithCaseAccentSensitivity)
Seq(AbstractArrayType(StringTypeWithCollation),
StringTypeWithCollation,
StringTypeWithCollation)
} else {
Seq(AbstractArrayType(StringTypeWithCaseAccentSensitivity),
StringTypeWithCaseAccentSensitivity)
Seq(AbstractArrayType(StringTypeWithCollation),
StringTypeWithCollation)
}

override def children: Seq[Expression] = if (nullReplacement.isDefined) {
Expand Down Expand Up @@ -2861,7 +2861,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
with QueryErrorsBase {

private def allowedTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCaseAccentSensitivity, BinaryType, ArrayType)
Seq(StringTypeWithCollation, BinaryType, ArrayType)

final override val nodePatterns: Seq[TreePattern] = Seq(CONCAT)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.TypeUtils._
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity
import org.apache.spark.sql.internal.types.StringTypeWithCollation
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

Expand Down Expand Up @@ -147,7 +147,7 @@ case class CsvToStructs(
converter(parser.parse(csv))
}

override def inputTypes: Seq[AbstractDataType] = StringTypeWithCaseAccentSensitivity :: Nil
override def inputTypes: Seq[AbstractDataType] = StringTypeWithCollation :: Nil

override def prettyName: String = "from_csv"

Expand Down
Loading