Skip to content

Commit

Permalink
[SPARK-44840][SQL] Make array_insert() 1-based for negative indexes
Browse files Browse the repository at this point in the history
In the PR, I propose to make the `array_insert` function 1-based for negative indexes. So, the maximum negative index -1 should point to the last element, and the function should insert the new element at the end of the given array for the index -1.

The old behaviour can be restored via the SQL config `spark.sql.legacy.negativeIndexInArrayInsert`.

1.  To match the behaviour of functions such as `substr()` and `element_at()`.
```sql
spark-sql (default)> select element_at(array('a', 'b'), -1), substr('ab', -1);
b	b
```
2. To fix an inconsistency in `array_insert` in which positive indexes are 1-based, but negative indexes are 0-based.

Yes.

Before:
```sql
spark-sql (default)> select array_insert(array('a', 'b'), -1, 'c');
["a","c","b"]
```

After:
```sql
spark-sql (default)> select array_insert(array('a', 'b'), -1, 'c');
["a","b","c"]
```

By running the modified test suite:
```
$ build/sbt "test:testOnly *CollectionExpressionsSuite"
$ build/sbt "test:testOnly *DataFrameFunctionsSuite"
$ PYSPARK_PYTHON=python3 build/sbt "sql/testOnly org.apache.spark.sql.SQLQueryTestSuite"
```

Closes apache#42564 from MaxGekk/fix-array_insert.

Authored-by: Max Gekk <max.gekk@gmail.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
(cherry picked from commit ce50a56)
Signed-off-by: Max Gekk <max.gekk@gmail.com>
  • Loading branch information
MaxGekk committed Aug 24, 2023
1 parent 21a86b6 commit 7310d61
Show file tree
Hide file tree
Showing 11 changed files with 149 additions and 41 deletions.
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [array_insert(e#0, 0, 1) AS array_insert(e, 0, 1)#0]
Project [array_insert(e#0, 0, 1, false) AS array_insert(e, 0, 1)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
4 changes: 4 additions & 0 deletions docs/sql-migration-guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ license: |
* Table of contents
{:toc}

## Upgrading from Spark SQL 3.4.1 to 3.4.2

- Since Spark 3.4.2, the `array_insert` function is 1-based for negative indexes. It inserts a new element at the end of the input array for the index -1. To restore the previous behavior, set `spark.sql.legacy.negativeIndexInArrayInsert` to `true`.

## Upgrading from Spark SQL 3.3 to 3.4

- Since Spark 3.4, INSERT INTO commands with explicit column lists comprising fewer columns than the target table will automatically add the corresponding default values for the remaining columns (or NULL for any column lacking an explicitly-assigned default value). In Spark 3.3 or earlier, these commands would have failed returning errors reporting that the number of provided columns does not match the number of columns in the target table. Note that disabling `spark.sql.defaultColumn.useNullsForMissingDefaultValues` will restore the previous behavior.
Expand Down
2 changes: 1 addition & 1 deletion python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7726,7 +7726,7 @@ def array_insert(arr: "ColumnOrName", pos: Union["ColumnOrName", int], value: An
... ['data', 'pos', 'val']
... )
>>> df.select(array_insert(df.data, df.pos.cast('integer'), df.val).alias('data')).collect()
[Row(data=['a', 'd', 'b', 'c']), Row(data=['c', 'd', 'b', 'a'])]
[Row(data=['a', 'd', 'b', 'c']), Row(data=['c', 'b', 'd', 'a'])]
>>> df.select(array_insert(df.data, 5, 'hello').alias('data')).collect()
[Row(data=['a', 'b', 'c', None, 'hello']), Row(data=['c', 'b', 'a', None, 'hello'])]
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4605,23 +4605,34 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
@ExpressionDescription(
usage = """
_FUNC_(x, pos, val) - Places val into index pos of array x.
Array indices start at 1, or start from the end if index is negative.
Array indices start at 1. The maximum negative index is -1 for which the function inserts
new element after the current last element.
Index above array size appends the array, or prepends the array if index is negative,
with 'null' elements.
""",
examples = """
Examples:
> SELECT _FUNC_(array(1, 2, 3, 4), 5, 5);
[1,2,3,4,5]
> SELECT _FUNC_(array(5, 3, 2, 1), -3, 4);
> SELECT _FUNC_(array(5, 4, 3, 2), -1, 1);
[5,4,3,2,1]
> SELECT _FUNC_(array(5, 3, 2, 1), -4, 4);
[5,4,3,2,1]
""",
group = "array_funcs",
since = "3.4.0")
case class ArrayInsert(srcArrayExpr: Expression, posExpr: Expression, itemExpr: Expression)
case class ArrayInsert(
srcArrayExpr: Expression,
posExpr: Expression,
itemExpr: Expression,
legacyNegativeIndex: Boolean)
extends TernaryExpression with ImplicitCastInputTypes with ComplexTypeMergingExpression
with QueryErrorsBase with SupportQueryContext {

def this(srcArrayExpr: Expression, posExpr: Expression, itemExpr: Expression) = {
this(srcArrayExpr, posExpr, itemExpr, SQLConf.get.legacyNegativeIndexInArrayInsert)
}

override def inputTypes: Seq[AbstractDataType] = {
(srcArrayExpr.dataType, posExpr.dataType, itemExpr.dataType) match {
case (ArrayType(e1, hasNull), e2: IntegralType, e3) if (e2 != LongType) =>
Expand Down Expand Up @@ -4683,11 +4694,12 @@ case class ArrayInsert(srcArrayExpr: Expression, posExpr: Expression, itemExpr:
val newPosExtendsArrayLeft = (posInt < 0) && (-posInt > baseArr.numElements())

if (newPosExtendsArrayLeft) {
val baseOffset = if (legacyNegativeIndex) 1 else 0
// special case- if the new position is negative but larger than the current array size
// place the new item at start of array, place the current array contents at the end
// and fill the newly created array elements inbetween with a null

val newArrayLength = -posInt + 1
val newArrayLength = -posInt + baseOffset

if (newArrayLength > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
throw QueryExecutionErrors.concatArraysWithElementsExceedLimitError(newArrayLength)
Expand All @@ -4697,7 +4709,7 @@ case class ArrayInsert(srcArrayExpr: Expression, posExpr: Expression, itemExpr:

baseArr.foreach(arrayElementType, (i, v) => {
// current position, offset by new item + new null array elements
val elementPosition = i + 1 + math.abs(posInt + baseArr.numElements())
val elementPosition = i + baseOffset + math.abs(posInt + baseArr.numElements())
newArray(elementPosition) = v
})

Expand All @@ -4706,7 +4718,7 @@ case class ArrayInsert(srcArrayExpr: Expression, posExpr: Expression, itemExpr:
return new GenericArrayData(newArray)
} else {
if (posInt < 0) {
posInt = posInt + baseArr.numElements()
posInt = posInt + baseArr.numElements() + (if (legacyNegativeIndex) 0 else 1)
} else if (posInt > 0) {
posInt = posInt - 1
}
Expand Down Expand Up @@ -4738,6 +4750,7 @@ case class ArrayInsert(srcArrayExpr: Expression, posExpr: Expression, itemExpr:
val arr = arrExpr.value
val pos = posExpr.value
val item = itemExpr.value
val baseOffset = if (legacyNegativeIndex) 1 else 0

val itemInsertionIndex = ctx.freshName("itemInsertionIndex")
val adjustedAllocIdx = ctx.freshName("adjustedAllocIdx")
Expand Down Expand Up @@ -4765,29 +4778,29 @@ case class ArrayInsert(srcArrayExpr: Expression, posExpr: Expression, itemExpr:
|
|if ($pos < 0 && (java.lang.Math.abs($pos) > $arr.numElements())) {
|
| $resLength = java.lang.Math.abs($pos) + 1;
| $resLength = java.lang.Math.abs($pos) + $baseOffset;
| if ($resLength > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
| throw QueryExecutionErrors.createArrayWithElementsExceedLimitError($resLength);
| }
|
| $allocation
| for (int $i = 0; $i < $arr.numElements(); $i ++) {
| $adjustedAllocIdx = $i + 1 + java.lang.Math.abs($pos + $arr.numElements());
| $adjustedAllocIdx = $i + $baseOffset + java.lang.Math.abs($pos + $arr.numElements());
| $assignment
| }
| ${CodeGenerator.setArrayElement(
values, elementType, itemInsertionIndex, item, Some(insertedItemIsNull))}
|
| for (int $j = $pos + $arr.numElements(); $j < 0; $j ++) {
| $values.setNullAt($j + 1 + java.lang.Math.abs($pos + $arr.numElements()));
| for (int $j = ${if (legacyNegativeIndex) 0 else 1} + $pos + $arr.numElements(); $j < 0; $j ++) {
| $values.setNullAt($j + $baseOffset + java.lang.Math.abs($pos + $arr.numElements()));
| }
|
| ${ev.value} = $values;
|} else {
|
| $itemInsertionIndex = 0;
| if ($pos < 0) {
| $itemInsertionIndex = $pos + $arr.numElements();
| $itemInsertionIndex = $pos + $arr.numElements() + ${if (legacyNegativeIndex) 0 else 1};
| } else if ($pos > 0) {
| $itemInsertionIndex = $pos - 1;
| }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4185,6 +4185,18 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val LEGACY_NEGATIVE_INDEX_IN_ARRAY_INSERT =
buildConf("spark.sql.legacy.negativeIndexInArrayInsert")
.internal()
.doc("When set to true, restores the legacy behavior of `array_insert` for " +
"negative indexes - 0-based: the function inserts new element before the last one " +
"for the index -1. For example, `array_insert(['a', 'b'], -1, 'x')` returns " +
"`['a', 'x', 'b']`. When set to false, the -1 index points out to the last element, " +
"and the given example produces `['a', 'b', 'x']`.")
.version("3.5.0")
.booleanConf
.createWithDefault(false)

/**
* Holds information about keys that have been deprecated.
*
Expand Down Expand Up @@ -5013,6 +5025,10 @@ class SQLConf extends Serializable with Logging {
def allowsTempViewCreationWithMultipleNameparts: Boolean =
getConf(SQLConf.ALLOW_TEMP_VIEW_CREATION_WITH_MULTIPLE_NAME_PARTS)

def legacyNegativeIndexInArrayInsert: Boolean = {
getConf(SQLConf.LEGACY_NEGATIVE_INDEX_IN_ARRAY_INSERT)
}

/** ********************** SQLConf functionality methods ************ */

/** Set Spark SQL configuration properties. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2279,61 +2279,63 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
val a11 = Literal.create(null, ArrayType(StringType))

// basic additions per type
checkEvaluation(ArrayInsert(a1, Literal(3), Literal(3)), Seq(1, 2, 3, 4))
checkEvaluation(new ArrayInsert(a1, Literal(3), Literal(3)), Seq(1, 2, 3, 4))
checkEvaluation(
ArrayInsert(a3, Literal.create(3, IntegerType), Literal(true)),
new ArrayInsert(a3, Literal.create(3, IntegerType), Literal(true)),
Seq[Boolean](true, false, true, true)
)
checkEvaluation(
ArrayInsert(
new ArrayInsert(
a4,
Literal(3),
Literal.create(5.asInstanceOf[Byte], ByteType)),
Seq[Byte](1, 2, 5, 3, 2))

checkEvaluation(
ArrayInsert(
new ArrayInsert(
a5,
Literal(3),
Literal.create(3.asInstanceOf[Short], ShortType)),
Seq[Short](1, 2, 3, 3, 2))

checkEvaluation(
ArrayInsert(a7, Literal(4), Literal(4.4)),
new ArrayInsert(a7, Literal(4), Literal(4.4)),
Seq[Double](1.1, 2.2, 3.3, 4.4, 2.2)
)

checkEvaluation(
ArrayInsert(a6, Literal(4), Literal(4.4F)),
new ArrayInsert(a6, Literal(4), Literal(4.4F)),
Seq(1.1F, 2.2F, 3.3F, 4.4F, 2.2F)
)
checkEvaluation(ArrayInsert(a8, Literal(3), Literal(3L)), Seq(1L, 2L, 3L, 4L))
checkEvaluation(ArrayInsert(a9, Literal(3), Literal("d")), Seq("b", "a", "d", "c"))
checkEvaluation(new ArrayInsert(a8, Literal(3), Literal(3L)), Seq(1L, 2L, 3L, 4L))
checkEvaluation(new ArrayInsert(a9, Literal(3), Literal("d")), Seq("b", "a", "d", "c"))

// index edge cases
checkEvaluation(ArrayInsert(a1, Literal(2), Literal(3)), Seq(1, 3, 2, 4))
checkEvaluation(ArrayInsert(a1, Literal(1), Literal(3)), Seq(3, 1, 2, 4))
checkEvaluation(ArrayInsert(a1, Literal(4), Literal(3)), Seq(1, 2, 4, 3))
checkEvaluation(ArrayInsert(a1, Literal(-2), Literal(3)), Seq(1, 3, 2, 4))
checkEvaluation(ArrayInsert(a1, Literal(-3), Literal(3)), Seq(3, 1, 2, 4))
checkEvaluation(ArrayInsert(a1, Literal(-4), Literal(3)), Seq(3, null, 1, 2, 4))
checkEvaluation(new ArrayInsert(a1, Literal(2), Literal(3)), Seq(1, 3, 2, 4))
checkEvaluation(new ArrayInsert(a1, Literal(1), Literal(3)), Seq(3, 1, 2, 4))
checkEvaluation(new ArrayInsert(a1, Literal(4), Literal(3)), Seq(1, 2, 4, 3))
checkEvaluation(new ArrayInsert(a1, Literal(-2), Literal(3)), Seq(1, 2, 3, 4))
checkEvaluation(new ArrayInsert(a1, Literal(-3), Literal(3)), Seq(1, 3, 2, 4))
checkEvaluation(new ArrayInsert(a1, Literal(-4), Literal(3)), Seq(3, 1, 2, 4))
checkEvaluation(new ArrayInsert(a1, Literal(-5), Literal(3)), Seq(3, null, 1, 2, 4))
checkEvaluation(
ArrayInsert(a1, Literal(10), Literal(3)),
new ArrayInsert(a1, Literal(10), Literal(3)),
Seq(1, 2, 4, null, null, null, null, null, null, 3)
)
checkEvaluation(
ArrayInsert(a1, Literal(-10), Literal(3)),
Seq(3, null, null, null, null, null, null, null, 1, 2, 4)
new ArrayInsert(a1, Literal(-10), Literal(3)),
Seq(3, null, null, null, null, null, null, 1, 2, 4)
)

// null handling
checkEvaluation(ArrayInsert(
checkEvaluation(new ArrayInsert(
a1, Literal(3), Literal.create(null, IntegerType)), Seq(1, 2, null, 4)
)
checkEvaluation(ArrayInsert(a2, Literal(3), Literal(3)), Seq(1, 2, 3, null, 4, 5, null))
checkEvaluation(ArrayInsert(a10, Literal(3), Literal("d")), Seq("b", null, "d", "a", "g", null))
checkEvaluation(ArrayInsert(a11, Literal(3), Literal("d")), null)
checkEvaluation(ArrayInsert(a10, Literal.create(null, IntegerType), Literal("d")), null)
checkEvaluation(new ArrayInsert(a2, Literal(3), Literal(3)), Seq(1, 2, 3, null, 4, 5, null))
checkEvaluation(new ArrayInsert(a10, Literal(3), Literal("d")),
Seq("b", null, "d", "a", "g", null))
checkEvaluation(new ArrayInsert(a11, Literal(3), Literal("d")), null)
checkEvaluation(new ArrayInsert(a10, Literal.create(null, IntegerType), Literal("d")), null)
}

test("Array Intersect") {
Expand Down Expand Up @@ -2754,14 +2756,14 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper

test("SPARK-42401: Array insert of null value (explicit)") {
val a = Literal.create(Seq("b", "a", "c"), ArrayType(StringType, false))
checkEvaluation(ArrayInsert(
checkEvaluation(new ArrayInsert(
a, Literal(2), Literal.create(null, StringType)), Seq("b", null, "a", "c")
)
}

test("SPARK-42401: Array insert of null value (implicit)") {
val a = Literal.create(Seq("b", "a", "c"), ArrayType(StringType, false))
checkEvaluation(ArrayInsert(
checkEvaluation(new ArrayInsert(
a, Literal(5), Literal.create("q", StringType)), Seq("b", "a", "c", null, "q")
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4073,7 +4073,7 @@ object functions {
* @since 3.4.0
*/
def array_insert(arr: Column, pos: Column, value: Column): Column = withExpr {
ArrayInsert(arr.expr, pos.expr, value.expr)
new ArrayInsert(arr.expr, pos.expr, value.expr)
}

/**
Expand Down
5 changes: 5 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/array.sql
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,11 @@ select array_insert(array(1, 2, 3, NULL), 4, cast(NULL as INT));
select array_insert(array(2, 3, NULL, 4), 5, 5);
select array_insert(array(2, 3, NULL, 4), -5, 1);

set spark.sql.legacy.negativeIndexInArrayInsert=true;
select array_insert(array(1, 3, 4), -2, 2);
select array_insert(array(2, 3, NULL, 4), -5, 1);
set spark.sql.legacy.negativeIndexInArrayInsert=false;

-- function array_compact
select array_compact(id) from values (1) as t(id);
select array_compact(array("1", null, "2", null));
Expand Down
34 changes: 33 additions & 1 deletion sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,7 @@ select array_insert(array(1, 3, 4), -2, 2)
-- !query schema
struct<array_insert(array(1, 3, 4), -2, 2):array<int>>
-- !query output
[1,2,3,4]
[1,3,2,4]


-- !query
Expand Down Expand Up @@ -651,6 +651,30 @@ struct<array_insert(array(2, 3, NULL, 4), 5, 5):array<int>>
[2,3,null,4,5]


-- !query
select array_insert(array(2, 3, NULL, 4), -5, 1)
-- !query schema
struct<array_insert(array(2, 3, NULL, 4), -5, 1):array<int>>
-- !query output
[1,2,3,null,4]


-- !query
set spark.sql.legacy.negativeIndexInArrayInsert=true
-- !query schema
struct<key:string,value:string>
-- !query output
spark.sql.legacy.negativeIndexInArrayInsert true


-- !query
select array_insert(array(1, 3, 4), -2, 2)
-- !query schema
struct<array_insert(array(1, 3, 4), -2, 2):array<int>>
-- !query output
[1,2,3,4]


-- !query
select array_insert(array(2, 3, NULL, 4), -5, 1)
-- !query schema
Expand All @@ -659,6 +683,14 @@ struct<array_insert(array(2, 3, NULL, 4), -5, 1):array<int>>
[1,null,2,3,null,4]


-- !query
set spark.sql.legacy.negativeIndexInArrayInsert=false
-- !query schema
struct<key:string,value:string>
-- !query output
spark.sql.legacy.negativeIndexInArrayInsert false


-- !query
select array_compact(id) from values (1) as t(id)
-- !query schema
Expand Down
Loading

0 comments on commit 7310d61

Please sign in to comment.