Skip to content

Commit

Permalink
[SPARK-24636][SQL] Type coercion of arrays for array_join function
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?
Presto's implementation accepts arrays of arbitrary primitive types as input:

```
presto> SELECT array_join(ARRAY [1, 2, 3], ', ');
_col0
---------
1, 2, 3
(1 row)
```

This PR proposes to implement a type coercion rule for the `array_join` function that converts arrays of primitive as well as non-primitive types to arrays of strings.

## How was this patch tested?

New test cases were added to:
- sql-tests/inputs/typeCoercion/native/arrayJoin.sql
- DataFrameFunctionsSuite.scala

Author: Marek Novotny <mn.mikke@gmail.com>

Closes #21620 from mn-mikke/SPARK-24636.
  • Loading branch information
mn-mikke authored and HyukjinKwon committed Jun 26, 2018
1 parent c7967c6 commit e07aee2
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,14 @@ object TypeCoercion {
case None => c
}

// Coerce the array argument of ArrayJoin to an array of strings.
// The guard fires only when the argument passes ArrayType.acceptsType but is not
// already accepted by ArrayType(StringType).
case aj @ ArrayJoin(arr, d, nr) if !ArrayType(StringType).acceptsType(arr.dataType) &&
ArrayType.acceptsType(arr.dataType) =>
// The asInstanceOf relies on the guard's ArrayType.acceptsType check having passed.
// The source array's containsNull flag is carried over to the target type.
val containsNull = arr.dataType.asInstanceOf[ArrayType].containsNull
ImplicitTypeCasts.implicitCast(arr, ArrayType(StringType, containsNull)) match {
// An implicit cast is available: rebuild the expression around the casted array,
// keeping the other two arguments (d, nr) as they were.
case Some(castedArr) => ArrayJoin(castedArr, d, nr)
// No implicit cast is available: return the expression unchanged.
case None => aj
}

case m @ CreateMap(children) if m.keys.length == m.values.length &&
(!haveSameType(m.keys) || !haveSameType(m.values)) =>
val newKeys = if (haveSameType(m.keys)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1621,6 +1621,7 @@ case class ArrayJoin(

override def dataType: DataType = StringType

override def prettyName: String = "array_join"
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
-- Exercises array_join over arrays of different element types, one element type per query,
-- always with ', ' as the delimiter. The final query uses an array of string literals.
SELECT array_join(array(true, false), ', ');
SELECT array_join(array(2Y, 1Y), ', ');
SELECT array_join(array(2S, 1S), ', ');
SELECT array_join(array(2, 1), ', ');
SELECT array_join(array(2L, 1L), ', ');
-- Both literals below are larger than the maximum signed 64-bit value (9223372036854775807).
SELECT array_join(array(9223372036854775809, 9223372036854775808), ', ');
SELECT array_join(array(2.0D, 1.0D), ', ');
SELECT array_join(array(float(2.0), float(1.0)), ', ');
SELECT array_join(array(date '2016-03-14', date '2016-03-13'), ', ');
SELECT array_join(array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), ', ');
SELECT array_join(array('a', 'b'), ', ');
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 11


-- !query 0
SELECT array_join(array(true, false), ', ')
-- !query 0 schema
struct<array_join(array(true, false), , ):string>
-- !query 0 output
true, false


-- !query 1
SELECT array_join(array(2Y, 1Y), ', ')
-- !query 1 schema
struct<array_join(array(2, 1), , ):string>
-- !query 1 output
2, 1


-- !query 2
SELECT array_join(array(2S, 1S), ', ')
-- !query 2 schema
struct<array_join(array(2, 1), , ):string>
-- !query 2 output
2, 1


-- !query 3
SELECT array_join(array(2, 1), ', ')
-- !query 3 schema
struct<array_join(array(2, 1), , ):string>
-- !query 3 output
2, 1


-- !query 4
SELECT array_join(array(2L, 1L), ', ')
-- !query 4 schema
struct<array_join(array(2, 1), , ):string>
-- !query 4 output
2, 1


-- !query 5
SELECT array_join(array(9223372036854775809, 9223372036854775808), ', ')
-- !query 5 schema
struct<array_join(array(9223372036854775809, 9223372036854775808), , ):string>
-- !query 5 output
9223372036854775809, 9223372036854775808


-- !query 6
SELECT array_join(array(2.0D, 1.0D), ', ')
-- !query 6 schema
struct<array_join(array(2.0, 1.0), , ):string>
-- !query 6 output
2.0, 1.0


-- !query 7
SELECT array_join(array(float(2.0), float(1.0)), ', ')
-- !query 7 schema
struct<array_join(array(CAST(2.0 AS FLOAT), CAST(1.0 AS FLOAT)), , ):string>
-- !query 7 output
2.0, 1.0


-- !query 8
SELECT array_join(array(date '2016-03-14', date '2016-03-13'), ', ')
-- !query 8 schema
struct<array_join(array(DATE '2016-03-14', DATE '2016-03-13'), , ):string>
-- !query 8 output
2016-03-14, 2016-03-13


-- !query 9
SELECT array_join(array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), ', ')
-- !query 9 schema
struct<array_join(array(TIMESTAMP('2016-11-15 20:54:00.0'), TIMESTAMP('2016-11-12 20:54:00.0')), , ):string>
-- !query 9 output
2016-11-15 20:54:00, 2016-11-12 20:54:00


-- !query 10
SELECT array_join(array('a', 'b'), ', ')
-- !query 10 schema
struct<array_join(array(a, b), , ):string>
-- !query 10 output
a, b
Original file line number Diff line number Diff line change
Expand Up @@ -805,6 +805,23 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
checkAnswer(
df.selectExpr("array_join(x, delimiter, 'NULL')"),
Seq(Row("a,b"), Row("a,NULL,b"), Row("")))

val idf = Seq(Seq(1, 2, 3)).toDF("x")

checkAnswer(
idf.select(array_join(idf("x"), ", ")),
Seq(Row("1, 2, 3"))
)
checkAnswer(
idf.selectExpr("array_join(x, ', ')"),
Seq(Row("1, 2, 3"))
)
intercept[AnalysisException] {
idf.selectExpr("array_join(x, 1)")
}
intercept[AnalysisException] {
idf.selectExpr("array_join(x, ', ', 1)")
}
}

test("array_min function") {
Expand Down

0 comments on commit e07aee2

Please sign in to comment.