From 3c34183d4fef8e7398037d5eb0b336efba5efa54 Mon Sep 17 00:00:00 2001
From: Alfonso Roa
Date: Thu, 25 Aug 2022 19:06:54 +0200
Subject: [PATCH] Empty arrays return the expected datatype

---
 .../main/scala/doric/syntax/ArrayColumns.scala  | 12 +++++++++---
 .../scala/doric/syntax/ArrayColumnsSpec.scala   | 15 +++++++++++++++
 2 files changed, 24 insertions(+), 3 deletions(-)

diff --git a/core/src/main/scala/doric/syntax/ArrayColumns.scala b/core/src/main/scala/doric/syntax/ArrayColumns.scala
index 3ba2e43bb..ec83c79bd 100644
--- a/core/src/main/scala/doric/syntax/ArrayColumns.scala
+++ b/core/src/main/scala/doric/syntax/ArrayColumns.scala
@@ -2,10 +2,11 @@ package doric
 package syntax
 
 import scala.language.higherKinds
+import scala.reflect.ClassTag
 
 import cats.data.Kleisli
 import cats.implicits._
-import doric.types.CollectionType
+import doric.types.{CollectionType, LiteralSparkType, SparkType}
 import org.apache.spark.sql.{Column, Dataset, functions => f}
 import org.apache.spark.sql.catalyst.expressions._
 
@@ -58,8 +59,13 @@ private[syntax] trait ArrayColumns {
     * @see org.apache.spark.sql.functions.array
     * @todo scaladoc link (issue #135)
     */
-  def array[T](cols: DoricColumn[T]*): ArrayColumn[T] =
-    cols.toList.traverse(_.elem).map(f.array(_: _*)).toDC
+  def array[T: SparkType: ClassTag](
+      cols: DoricColumn[T]*
+  )(implicit lt: LiteralSparkType[Array[T]]): ArrayColumn[T] =
+    if (cols.nonEmpty)
+      cols.toList.traverse(_.elem).map(f.array(_: _*)).toDC
+    else
+      lit(Array.empty[T])
 
   /**
    * Creates a new list column. The input columns must all have the same data type.
diff --git a/core/src/test/scala/doric/syntax/ArrayColumnsSpec.scala b/core/src/test/scala/doric/syntax/ArrayColumnsSpec.scala
index ade51a550..64b0a36cd 100644
--- a/core/src/test/scala/doric/syntax/ArrayColumnsSpec.scala
+++ b/core/src/test/scala/doric/syntax/ArrayColumnsSpec.scala
@@ -2,6 +2,7 @@ package doric
 package syntax
 
 import doric.sem.{ChildColumnNotFound, ColumnTypeError, DoricMultiError, SparkErrorWrapper}
+import doric.types.SparkType
 import org.apache.spark.sql.{Row, functions => f}
 import org.apache.spark.sql.types.{IntegerType, LongType, StringType}
 
@@ -300,6 +301,20 @@ class ArrayColumnsSpec extends DoricTestElements {
         List(Some(Array("a", "b")))
       )
     }
+
+    it("should be of the expected type when empty") {
+      val df = spark
+        .range(5)
+        .select(
+          array[Long]().as("l"),
+          array[String]().as("s"),
+          array[(String, String)]().as("r")
+        )
+
+      assert(df("l").expr.dataType === SparkType[Array[Long]].dataType)
+      assert(df("s").expr.dataType === SparkType[Array[String]].dataType)
+      assert(df("r").expr.dataType === SparkType[Array[(String, String)]].dataType)
+    }
   }
 
   describe("list doric function") {