Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-25829][SQL][FOLLOWUP] Refactor MapConcat in order to check properly the limit size #23217

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -554,13 +554,6 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
return null
}

val numElements = maps.foldLeft(0L)((sum, ad) => sum + ad.numElements())
if (numElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
throw new RuntimeException(s"Unsuccessful attempt to concat maps with $numElements " +
s"elements due to exceeding the map size limit " +
s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
}

for (map <- maps) {
mapBuilder.putAll(map.keyArray(), map.valueArray())
}
Expand All @@ -569,8 +562,6 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val mapCodes = children.map(_.genCode(ctx))
val keyType = dataType.keyType
val valueType = dataType.valueType
val argsName = ctx.freshName("args")
val hasNullName = ctx.freshName("hasNull")
val builderTerm = ctx.addReferenceObj("mapBuilder", mapBuilder)
Expand Down Expand Up @@ -610,41 +601,12 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
)

val idxName = ctx.freshName("idx")
val numElementsName = ctx.freshName("numElems")
val finKeysName = ctx.freshName("finalKeys")
val finValsName = ctx.freshName("finalValues")

val keyConcat = genCodeForArrays(ctx, keyType, false)

val valueConcat =
if (valueType.sameType(keyType) &&
!(CodeGenerator.isPrimitiveType(valueType) && dataType.valueContainsNull)) {
keyConcat
} else {
genCodeForArrays(ctx, valueType, dataType.valueContainsNull)
}

val keyArgsName = ctx.freshName("keyArgs")
val valArgsName = ctx.freshName("valArgs")

val mapMerge =
s"""
|ArrayData[] $keyArgsName = new ArrayData[${mapCodes.size}];
|ArrayData[] $valArgsName = new ArrayData[${mapCodes.size}];
|long $numElementsName = 0;
|for (int $idxName = 0; $idxName < $argsName.length; $idxName++) {
| $keyArgsName[$idxName] = $argsName[$idxName].keyArray();
| $valArgsName[$idxName] = $argsName[$idxName].valueArray();
| $numElementsName += $argsName[$idxName].numElements();
| $builderTerm.putAll($argsName[$idxName].keyArray(), $argsName[$idxName].valueArray());
|}
|if ($numElementsName > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
| throw new RuntimeException("Unsuccessful attempt to concat maps with " +
| $numElementsName + " elements due to exceeding the map size limit " +
| "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.");
|}
|ArrayData $finKeysName = $keyConcat($keyArgsName, (int) $numElementsName);
|ArrayData $finValsName = $valueConcat($valArgsName, (int) $numElementsName);
|${ev.value} = $builderTerm.from($finKeysName, $finValsName);
|${ev.value} = $builderTerm.build();
""".stripMargin

ev.copy(
Expand All @@ -660,41 +622,6 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
""".stripMargin)
}

private def genCodeForArrays(
ctx: CodegenContext,
elementType: DataType,
checkForNull: Boolean): String = {
val counter = ctx.freshName("counter")
val arrayData = ctx.freshName("arrayData")
val argsName = ctx.freshName("args")
val numElemName = ctx.freshName("numElements")
val y = ctx.freshName("y")
val z = ctx.freshName("z")

val allocation = CodeGenerator.createArrayData(
arrayData, elementType, numElemName, s" $prettyName failed.")
val assignment = CodeGenerator.createArrayAssignment(
arrayData, elementType, s"$argsName[$y]", counter, z, checkForNull)

val concat = ctx.freshName("concat")
val concatDef =
s"""
|private ArrayData $concat(ArrayData[] $argsName, int $numElemName) {
| $allocation
| int $counter = 0;
| for (int $y = 0; $y < ${children.length}; $y++) {
| for (int $z = 0; $z < $argsName[$y].numElements(); $z++) {
| $assignment
| $counter++;
| }
| }
| return $arrayData;
|}
""".stripMargin

ctx.addNewFunction(concat, concatDef)
}

override def prettyName: String = "map_concat"
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import scala.collection.mutable

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.array.ByteArrayMethods

/**
* A builder of [[ArrayBasedMapData]], which fails if a null map key is detected, and removes
Expand Down Expand Up @@ -54,6 +55,10 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria

val index = keyToIndex.getOrDefault(key, -1)
if (index == -1) {
if (size >= ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
throw new RuntimeException(s"Unsuccessful attempt to build maps with $size elements " +
s"due to exceeding the map size limit ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
}
keyToIndex.put(key, values.length)
keys.append(key)
values.append(value)
Expand Down Expand Up @@ -117,4 +122,9 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria
build()
}
}

/**
* Returns the current size of the map which is going to be produced by the current builder.
*/
def size: Int = keys.size
}