diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala index f93e5736de401..7f45c96f85455 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import scala.collection.mutable +import scala.collection.{mutable, GenTraversableOnce} import scala.collection.mutable.ArrayBuffer object ExpressionSet { @@ -67,6 +67,12 @@ class ExpressionSet protected( newSet } + override def ++(elems: GenTraversableOnce[Expression]): ExpressionSet = { + val newSet = new ExpressionSet(baseSet.clone(), originals.clone()) + elems.foreach(newSet.add) + newSet + } + override def -(elem: Expression): ExpressionSet = { val newBaseSet = baseSet.clone().filterNot(_ == elem.canonicalized) val newOriginals = originals.clone().filterNot(_.canonicalized == elem.canonicalized) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala index d617ad540d5ff..e5d8771564ad2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala @@ -210,4 +210,14 @@ class ExpressionSetSuite extends SparkFunSuite { assert((initialSet - (aLower + 1)).size == 0) } + + test("add multiple elements to set") { + val initialSet = ExpressionSet(aUpper + 1 :: Nil) + val setToAddWithSameExpression = ExpressionSet(aUpper + 1 :: aUpper + 2 :: Nil) + val setToAddWithOutSameExpression = ExpressionSet(aUpper + 3 :: aUpper + 4 :: Nil) + + assert((initialSet ++ setToAddWithSameExpression).size == 2) + assert((initialSet ++ setToAddWithOutSameExpression).size == 3) + } + }