[SPARK-44840][SQL] Make array_insert() 1-based for negative indexes
### What changes were proposed in this pull request?
In this PR, I propose to make the `array_insert` function 1-based for negative indexes. The maximum negative index, -1, should point to the last element, so the function should insert a new element at the end of the given array for index -1.

The old behaviour can be restored via the SQL config `spark.sql.legacy.negativeIndexInArrayInsert`.
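
For illustration, here is how the switch behaves in a `spark-sql` session (a sketch; the results follow from the before/after examples below):
```sql
-- Legacy 0-based semantics for negative indexes (pre-3.5 behaviour).
SET spark.sql.legacy.negativeIndexInArrayInsert=true;
SELECT array_insert(array('a', 'b'), -1, 'c');
-- ["a","c","b"]

-- New 1-based semantics: -1 points to the last element, so 'c' is appended.
SET spark.sql.legacy.negativeIndexInArrayInsert=false;
SELECT array_insert(array('a', 'b'), -1, 'c');
-- ["a","b","c"]
```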

### Why are the changes needed?
1. To match the behaviour of functions such as `substr()` and `element_at()`.
```sql
spark-sql (default)> select element_at(array('a', 'b'), -1), substr('ab', -1);
b	b
```
2. To fix an inconsistency in `array_insert` in which positive indexes are 1-based, but negative indexes are 0-based.

### Does this PR introduce _any_ user-facing change?
Yes.

Before:
```sql
spark-sql (default)> select array_insert(array('a', 'b'), -1, 'c');
["a","c","b"]
```

After:
```sql
spark-sql (default)> select array_insert(array('a', 'b'), -1, 'c');
["a","b","c"]
```
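
Out-of-range indexes continue to pad with `null`s on the corresponding side; a sketch of both directions under the new semantics (results derived from the updated docs and tests, not re-run here):
```sql
-- Positive index past the end: pad with nulls, then place the new element.
SELECT array_insert(array(1, 2), 5, 9);
-- [1,2,null,null,9]

-- Negative index past the start: place the new element first, then pad with nulls.
SELECT array_insert(array('a', 'b'), -4, 'x');
-- ["x",null,"a","b"]
```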

### How was this patch tested?
By running the modified test suites:
```
$ build/sbt "test:testOnly *CollectionExpressionsSuite"
$ build/sbt "test:testOnly *DataFrameFunctionsSuite"
$ PYSPARK_PYTHON=python3 build/sbt "sql/testOnly org.apache.spark.sql.SQLQueryTestSuite"
```

Closes apache#42564 from MaxGekk/fix-array_insert.

Authored-by: Max Gekk <max.gekk@gmail.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
MaxGekk committed Aug 22, 2023
1 parent 24293ca commit ce50a56
Showing 14 changed files with 218 additions and 61 deletions.
@@ -1,2 +1,2 @@
Project [array_insert(e#0, 0, 1) AS array_insert(e, 0, 1)#0]
Project [array_insert(e#0, 0, 1, false) AS array_insert(e, 0, 1)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
@@ -1,2 +1,2 @@
Project [array_insert(e#0, 1, 1) AS array_prepend(e, 1)#0]
Project [array_insert(e#0, 1, 1, false) AS array_prepend(e, 1)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
1 change: 1 addition & 0 deletions docs/sql-migration-guide.md
@@ -29,6 +29,7 @@ license: |
- Since Spark 3.5, Row's json and prettyJson methods are moved to `ToJsonUtil`.
- Since Spark 3.5, the `plan` field is moved from `AnalysisException` to `EnhancedAnalysisException`.
- Since Spark 3.5, `spark.sql.optimizer.canChangeCachedPlanOutputPartitioning` is enabled by default. To restore the previous behavior, set `spark.sql.optimizer.canChangeCachedPlanOutputPartitioning` to `false`.
- Since Spark 3.5, the `array_insert` function is 1-based for negative indexes. It inserts new element at the end of input arrays for the index -1. To restore the previous behavior, set `spark.sql.legacy.negativeIndexInArrayInsert` to `true`.

## Upgrading from Spark SQL 3.3 to 3.4

2 changes: 1 addition & 1 deletion python/pyspark/sql/functions.py
@@ -11367,7 +11367,7 @@ def array_insert(arr: "ColumnOrName", pos: Union["ColumnOrName", int], value: An
... ['data', 'pos', 'val']
... )
>>> df.select(array_insert(df.data, df.pos.cast('integer'), df.val).alias('data')).collect()
[Row(data=['a', 'd', 'b', 'c']), Row(data=['c', 'd', 'b', 'a'])]
[Row(data=['a', 'd', 'b', 'c']), Row(data=['c', 'b', 'd', 'a'])]
>>> df.select(array_insert(df.data, 5, 'hello').alias('data')).collect()
[Row(data=['a', 'b', 'c', None, 'hello']), Row(data=['c', 'b', 'a', None, 'hello'])]
"""
@@ -1418,7 +1418,7 @@ case class ArrayContains(left: Expression, right: Expression)
case class ArrayPrepend(left: Expression, right: Expression) extends RuntimeReplaceable
with ImplicitCastInputTypes with BinaryLike[Expression] with QueryErrorsBase {

override lazy val replacement: Expression = ArrayInsert(left, Literal(1), right)
override lazy val replacement: Expression = new ArrayInsert(left, Literal(1), right)

override def inputTypes: Seq[AbstractDataType] = {
(left.dataType, right.dataType) match {
@@ -4674,23 +4674,34 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
@ExpressionDescription(
usage = """
_FUNC_(x, pos, val) - Places val into index pos of array x.
Array indices start at 1, or start from the end if index is negative.
Array indices start at 1. The maximum negative index is -1 for which the function inserts
new element after the current last element.
Index above array size appends the array, or prepends the array if index is negative,
with 'null' elements.
""",
examples = """
Examples:
> SELECT _FUNC_(array(1, 2, 3, 4), 5, 5);
[1,2,3,4,5]
> SELECT _FUNC_(array(5, 3, 2, 1), -3, 4);
> SELECT _FUNC_(array(5, 4, 3, 2), -1, 1);
[5,4,3,2,1]
> SELECT _FUNC_(array(5, 3, 2, 1), -4, 4);
[5,4,3,2,1]
""",
group = "array_funcs",
since = "3.4.0")
case class ArrayInsert(srcArrayExpr: Expression, posExpr: Expression, itemExpr: Expression)
case class ArrayInsert(
srcArrayExpr: Expression,
posExpr: Expression,
itemExpr: Expression,
legacyNegativeIndex: Boolean)
extends TernaryExpression with ImplicitCastInputTypes with ComplexTypeMergingExpression
with QueryErrorsBase with SupportQueryContext {

def this(srcArrayExpr: Expression, posExpr: Expression, itemExpr: Expression) = {
this(srcArrayExpr, posExpr, itemExpr, SQLConf.get.legacyNegativeIndexInArrayInsert)
}

override def inputTypes: Seq[AbstractDataType] = {
(srcArrayExpr.dataType, posExpr.dataType, itemExpr.dataType) match {
case (ArrayType(e1, hasNull), e2: IntegralType, e3) if (e2 != LongType) =>
@@ -4784,11 +4795,12 @@ case class ArrayInsert(srcArrayExpr: Expression, posExpr: Expression, itemExpr:
val newPosExtendsArrayLeft = (posInt < 0) && (-posInt > baseArr.numElements())

if (newPosExtendsArrayLeft) {
val baseOffset = if (legacyNegativeIndex) 1 else 0
// special case- if the new position is negative but larger than the current array size
// place the new item at start of array, place the current array contents at the end
// and fill the newly created array elements inbetween with a null

val newArrayLength = -posInt + 1
val newArrayLength = -posInt + baseOffset

if (newArrayLength > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
throw QueryExecutionErrors.concatArraysWithElementsExceedLimitError(newArrayLength)
@@ -4798,7 +4810,7 @@ case class ArrayInsert(srcArrayExpr: Expression, posExpr: Expression, itemExpr:

baseArr.foreach(elementType, (i, v) => {
// current position, offset by new item + new null array elements
val elementPosition = i + 1 + math.abs(posInt + baseArr.numElements())
val elementPosition = i + baseOffset + math.abs(posInt + baseArr.numElements())
newArray(elementPosition) = v
})

@@ -4807,7 +4819,7 @@ case class ArrayInsert(srcArrayExpr: Expression, posExpr: Expression, itemExpr:
new GenericArrayData(newArray)
} else {
if (posInt < 0) {
posInt = posInt + baseArr.numElements()
posInt = posInt + baseArr.numElements() + (if (legacyNegativeIndex) 0 else 1)
} else if (posInt > 0) {
posInt = posInt - 1
}
@@ -4883,6 +4895,7 @@ case class ArrayInsert(srcArrayExpr: Expression, posExpr: Expression, itemExpr:
|""".stripMargin
} else {
val pos = posExpr.value
val baseOffset = if (legacyNegativeIndex) 1 else 0
s"""
|int $itemInsertionIndex = 0;
|int $resLength = 0;
@@ -4895,29 +4908,29 @@
|
|if ($pos < 0 && (java.lang.Math.abs($pos) > $arr.numElements())) {
|
| $resLength = java.lang.Math.abs($pos) + 1;
| $resLength = java.lang.Math.abs($pos) + $baseOffset;
| if ($resLength > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
| throw QueryExecutionErrors.createArrayWithElementsExceedLimitError($resLength);
| }
|
| $allocation
| for (int $i = 0; $i < $arr.numElements(); $i ++) {
| $adjustedAllocIdx = $i + 1 + java.lang.Math.abs($pos + $arr.numElements());
| $adjustedAllocIdx = $i + $baseOffset + java.lang.Math.abs($pos + $arr.numElements());
| $assignment
| }
| ${CodeGenerator.setArrayElement(
values, elementType, itemInsertionIndex, item, Some(insertedItemIsNull))}
|
| for (int $j = $pos + $arr.numElements(); $j < 0; $j ++) {
| $values.setNullAt($j + 1 + java.lang.Math.abs($pos + $arr.numElements()));
| for (int $j = ${if (legacyNegativeIndex) 0 else 1} + $pos + $arr.numElements(); $j < 0; $j ++) {
| $values.setNullAt($j + $baseOffset + java.lang.Math.abs($pos + $arr.numElements()));
| }
|
| ${ev.value} = $values;
|} else {
|
| $itemInsertionIndex = 0;
| if ($pos < 0) {
| $itemInsertionIndex = $pos + $arr.numElements();
| $itemInsertionIndex = $pos + $arr.numElements() + ${if (legacyNegativeIndex) 0 else 1};
| } else if ($pos > 0) {
| $itemInsertionIndex = $pos - 1;
| }
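To make the new normalization in `ArrayInsert` concrete: an in-range negative index now resolves one slot further right than in legacy mode, since -1 means 'after the last element' rather than 'before it'. A sketch matching the updated `CollectionExpressionsSuite` cases:
```sql
SET spark.sql.legacy.negativeIndexInArrayInsert=false;
SELECT array_insert(array(1, 2, 4), -2, 3);
-- [1,2,3,4]: -2 is the second position from the end
SET spark.sql.legacy.negativeIndexInArrayInsert=true;
SELECT array_insert(array(1, 2, 4), -2, 3);
-- [1,3,2,4]: legacy 0-based counting from the end
```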
@@ -4378,6 +4378,18 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val LEGACY_NEGATIVE_INDEX_IN_ARRAY_INSERT =
buildConf("spark.sql.legacy.negativeIndexInArrayInsert")
.internal()
.doc("When set to true, restores the legacy behavior of `array_insert` for " +
"negative indexes - 0-based: the function inserts new element before the last one " +
"for the index -1. For example, `array_insert(['a', 'b'], -1, 'x')` returns " +
"`['a', 'x', 'b']`. When set to false, the -1 index points out to the last element, " +
"and the given example produces `['a', 'b', 'x']`.")
.version("3.5.0")
.booleanConf
.createWithDefault(false)

/**
* Holds information about keys that have been deprecated.
*
@@ -5231,6 +5243,10 @@ class SQLConf extends Serializable with Logging with SqlApiConf {

def usePartitionEvaluator: Boolean = getConf(SQLConf.USE_PARTITION_EVALUATOR)

def legacyNegativeIndexInArrayInsert: Boolean = {
getConf(SQLConf.LEGACY_NEGATIVE_INDEX_IN_ARRAY_INSERT)
}

/** ********************** SQLConf functionality methods ************ */

/** Set Spark SQL configuration properties. */
@@ -2279,61 +2279,63 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
val a11 = Literal.create(null, ArrayType(StringType))

// basic additions per type
checkEvaluation(ArrayInsert(a1, Literal(3), Literal(3)), Seq(1, 2, 3, 4))
checkEvaluation(new ArrayInsert(a1, Literal(3), Literal(3)), Seq(1, 2, 3, 4))
checkEvaluation(
ArrayInsert(a3, Literal.create(3, IntegerType), Literal(true)),
new ArrayInsert(a3, Literal.create(3, IntegerType), Literal(true)),
Seq[Boolean](true, false, true, true)
)
checkEvaluation(
ArrayInsert(
new ArrayInsert(
a4,
Literal(3),
Literal.create(5.asInstanceOf[Byte], ByteType)),
Seq[Byte](1, 2, 5, 3, 2))

checkEvaluation(
ArrayInsert(
new ArrayInsert(
a5,
Literal(3),
Literal.create(3.asInstanceOf[Short], ShortType)),
Seq[Short](1, 2, 3, 3, 2))

checkEvaluation(
ArrayInsert(a7, Literal(4), Literal(4.4)),
new ArrayInsert(a7, Literal(4), Literal(4.4)),
Seq[Double](1.1, 2.2, 3.3, 4.4, 2.2)
)

checkEvaluation(
ArrayInsert(a6, Literal(4), Literal(4.4F)),
new ArrayInsert(a6, Literal(4), Literal(4.4F)),
Seq(1.1F, 2.2F, 3.3F, 4.4F, 2.2F)
)
checkEvaluation(ArrayInsert(a8, Literal(3), Literal(3L)), Seq(1L, 2L, 3L, 4L))
checkEvaluation(ArrayInsert(a9, Literal(3), Literal("d")), Seq("b", "a", "d", "c"))
checkEvaluation(new ArrayInsert(a8, Literal(3), Literal(3L)), Seq(1L, 2L, 3L, 4L))
checkEvaluation(new ArrayInsert(a9, Literal(3), Literal("d")), Seq("b", "a", "d", "c"))

// index edge cases
checkEvaluation(ArrayInsert(a1, Literal(2), Literal(3)), Seq(1, 3, 2, 4))
checkEvaluation(ArrayInsert(a1, Literal(1), Literal(3)), Seq(3, 1, 2, 4))
checkEvaluation(ArrayInsert(a1, Literal(4), Literal(3)), Seq(1, 2, 4, 3))
checkEvaluation(ArrayInsert(a1, Literal(-2), Literal(3)), Seq(1, 3, 2, 4))
checkEvaluation(ArrayInsert(a1, Literal(-3), Literal(3)), Seq(3, 1, 2, 4))
checkEvaluation(ArrayInsert(a1, Literal(-4), Literal(3)), Seq(3, null, 1, 2, 4))
checkEvaluation(new ArrayInsert(a1, Literal(2), Literal(3)), Seq(1, 3, 2, 4))
checkEvaluation(new ArrayInsert(a1, Literal(1), Literal(3)), Seq(3, 1, 2, 4))
checkEvaluation(new ArrayInsert(a1, Literal(4), Literal(3)), Seq(1, 2, 4, 3))
checkEvaluation(new ArrayInsert(a1, Literal(-2), Literal(3)), Seq(1, 2, 3, 4))
checkEvaluation(new ArrayInsert(a1, Literal(-3), Literal(3)), Seq(1, 3, 2, 4))
checkEvaluation(new ArrayInsert(a1, Literal(-4), Literal(3)), Seq(3, 1, 2, 4))
checkEvaluation(new ArrayInsert(a1, Literal(-5), Literal(3)), Seq(3, null, 1, 2, 4))
checkEvaluation(
ArrayInsert(a1, Literal(10), Literal(3)),
new ArrayInsert(a1, Literal(10), Literal(3)),
Seq(1, 2, 4, null, null, null, null, null, null, 3)
)
checkEvaluation(
ArrayInsert(a1, Literal(-10), Literal(3)),
Seq(3, null, null, null, null, null, null, null, 1, 2, 4)
new ArrayInsert(a1, Literal(-10), Literal(3)),
Seq(3, null, null, null, null, null, null, 1, 2, 4)
)

// null handling
checkEvaluation(ArrayInsert(
checkEvaluation(new ArrayInsert(
a1, Literal(3), Literal.create(null, IntegerType)), Seq(1, 2, null, 4)
)
checkEvaluation(ArrayInsert(a2, Literal(3), Literal(3)), Seq(1, 2, 3, null, 4, 5, null))
checkEvaluation(ArrayInsert(a10, Literal(3), Literal("d")), Seq("b", null, "d", "a", "g", null))
checkEvaluation(ArrayInsert(a11, Literal(3), Literal("d")), null)
checkEvaluation(ArrayInsert(a10, Literal.create(null, IntegerType), Literal("d")), null)
checkEvaluation(new ArrayInsert(a2, Literal(3), Literal(3)), Seq(1, 2, 3, null, 4, 5, null))
checkEvaluation(new ArrayInsert(a10, Literal(3), Literal("d")),
Seq("b", null, "d", "a", "g", null))
checkEvaluation(new ArrayInsert(a11, Literal(3), Literal("d")), null)
checkEvaluation(new ArrayInsert(a10, Literal.create(null, IntegerType), Literal("d")), null)
}

test("Array Intersect") {
@@ -2754,14 +2756,14 @@

test("SPARK-42401: Array insert of null value (explicit)") {
val a = Literal.create(Seq("b", "a", "c"), ArrayType(StringType, false))
checkEvaluation(ArrayInsert(
checkEvaluation(new ArrayInsert(
a, Literal(2), Literal.create(null, StringType)), Seq("b", null, "a", "c")
)
}

test("SPARK-42401: Array insert of null value (implicit)") {
val a = Literal.create(Seq("b", "a", "c"), ArrayType(StringType, false))
checkEvaluation(ArrayInsert(
checkEvaluation(new ArrayInsert(
a, Literal(5), Literal.create("q", StringType)), Seq("b", "a", "c", null, "q")
)
}
@@ -6155,7 +6155,7 @@ object functions {
* @since 3.4.0
*/
def array_insert(arr: Column, pos: Column, value: Column): Column = withExpr {
ArrayInsert(arr.expr, pos.expr, value.expr)
new ArrayInsert(arr.expr, pos.expr, value.expr)
}

/**
@@ -447,28 +447,28 @@ Project [get(array(1, 2, 3), -1) AS get(array(1, 2, 3), -1)#x]
-- !query
select array_insert(array(1, 2, 3), 3, 4)
-- !query analysis
Project [array_insert(array(1, 2, 3), 3, 4) AS array_insert(array(1, 2, 3), 3, 4)#x]
Project [array_insert(array(1, 2, 3), 3, 4, false) AS array_insert(array(1, 2, 3), 3, 4)#x]
+- OneRowRelation


-- !query
select array_insert(array(2, 3, 4), 0, 1)
-- !query analysis
Project [array_insert(array(2, 3, 4), 0, 1) AS array_insert(array(2, 3, 4), 0, 1)#x]
Project [array_insert(array(2, 3, 4), 0, 1, false) AS array_insert(array(2, 3, 4), 0, 1)#x]
+- OneRowRelation


-- !query
select array_insert(array(2, 3, 4), 1, 1)
-- !query analysis
Project [array_insert(array(2, 3, 4), 1, 1) AS array_insert(array(2, 3, 4), 1, 1)#x]
Project [array_insert(array(2, 3, 4), 1, 1, false) AS array_insert(array(2, 3, 4), 1, 1)#x]
+- OneRowRelation


-- !query
select array_insert(array(1, 3, 4), -2, 2)
-- !query analysis
Project [array_insert(array(1, 3, 4), -2, 2) AS array_insert(array(1, 3, 4), -2, 2)#x]
Project [array_insert(array(1, 3, 4), -2, 2, false) AS array_insert(array(1, 3, 4), -2, 2)#x]
+- OneRowRelation


@@ -499,38 +499,64 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
-- !query
select array_insert(cast(NULL as ARRAY<INT>), 1, 1)
-- !query analysis
Project [array_insert(cast(null as array<int>), 1, 1) AS array_insert(NULL, 1, 1)#x]
Project [array_insert(cast(null as array<int>), 1, 1, false) AS array_insert(NULL, 1, 1)#x]
+- OneRowRelation


-- !query
select array_insert(array(1, 2, 3, NULL), cast(NULL as INT), 4)
-- !query analysis
Project [array_insert(array(1, 2, 3, cast(null as int)), cast(null as int), 4) AS array_insert(array(1, 2, 3, NULL), CAST(NULL AS INT), 4)#x]
Project [array_insert(array(1, 2, 3, cast(null as int)), cast(null as int), 4, false) AS array_insert(array(1, 2, 3, NULL), CAST(NULL AS INT), 4)#x]
+- OneRowRelation


-- !query
select array_insert(array(1, 2, 3, NULL), 4, cast(NULL as INT))
-- !query analysis
Project [array_insert(array(1, 2, 3, cast(null as int)), 4, cast(null as int)) AS array_insert(array(1, 2, 3, NULL), 4, CAST(NULL AS INT))#x]
Project [array_insert(array(1, 2, 3, cast(null as int)), 4, cast(null as int), false) AS array_insert(array(1, 2, 3, NULL), 4, CAST(NULL AS INT))#x]
+- OneRowRelation


-- !query
select array_insert(array(2, 3, NULL, 4), 5, 5)
-- !query analysis
Project [array_insert(array(2, 3, cast(null as int), 4), 5, 5) AS array_insert(array(2, 3, NULL, 4), 5, 5)#x]
Project [array_insert(array(2, 3, cast(null as int), 4), 5, 5, false) AS array_insert(array(2, 3, NULL, 4), 5, 5)#x]
+- OneRowRelation


-- !query
select array_insert(array(2, 3, NULL, 4), -5, 1)
-- !query analysis
Project [array_insert(array(2, 3, cast(null as int), 4), -5, 1) AS array_insert(array(2, 3, NULL, 4), -5, 1)#x]
Project [array_insert(array(2, 3, cast(null as int), 4), -5, 1, false) AS array_insert(array(2, 3, NULL, 4), -5, 1)#x]
+- OneRowRelation


-- !query
set spark.sql.legacy.negativeIndexInArrayInsert=true
-- !query analysis
SetCommand (spark.sql.legacy.negativeIndexInArrayInsert,Some(true))


-- !query
select array_insert(array(1, 3, 4), -2, 2)
-- !query analysis
Project [array_insert(array(1, 3, 4), -2, 2, true) AS array_insert(array(1, 3, 4), -2, 2)#x]
+- OneRowRelation


-- !query
select array_insert(array(2, 3, NULL, 4), -5, 1)
-- !query analysis
Project [array_insert(array(2, 3, cast(null as int), 4), -5, 1, true) AS array_insert(array(2, 3, NULL, 4), -5, 1)#x]
+- OneRowRelation


-- !query
set spark.sql.legacy.negativeIndexInArrayInsert=false
-- !query analysis
SetCommand (spark.sql.legacy.negativeIndexInArrayInsert,Some(false))


-- !query
select array_compact(id) from values (1) as t(id)
-- !query analysis