From 55fd70381ba7e66be478443f7991bd0df6337fba Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Tue, 5 Dec 2023 12:31:56 +0800 Subject: [PATCH 1/4] init init nit --- python/pyspark/sql/dataframe.py | 13 ++++++++- .../tests/connect/test_parity_dataframe.py | 5 ++++ python/pyspark/sql/tests/test_dataframe.py | 9 +++++++ .../scala/org/apache/spark/sql/Dataset.scala | 27 +++++++++++++------ 4 files changed, 45 insertions(+), 9 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 5211d874ba33e..1419d1f3cb635 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -6272,7 +6272,18 @@ def withColumnsRenamed(self, colsMap: Dict[str, str]) -> "DataFrame": message_parameters={"arg_name": "colsMap", "arg_type": type(colsMap).__name__}, ) - return DataFrame(self._jdf.withColumnsRenamed(colsMap), self.sparkSession) + col_names: List[str] = [] + new_col_names: List[str] = [] + for k, v in colsMap.items(): + col_names.append(k) + new_col_names.append(v) + + return DataFrame( + self._jdf.withColumnsRenamed( + _to_seq(self._sc, col_names), _to_seq(self._sc, new_col_names) + ), + self.sparkSession, + ) def withMetadata(self, columnName: str, metadata: Dict[str, Any]) -> "DataFrame": """Returns a new :class:`DataFrame` by updating an existing column with metadata. diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py b/python/pyspark/sql/tests/connect/test_parity_dataframe.py index b7b4fdcd287b3..782a7ae31f2db 100644 --- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py +++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py @@ -77,6 +77,11 @@ def test_to_pandas_from_mixed_dataframe(self): def test_toDF_with_string(self): super().test_toDF_with_string() + # TODO(SPARK-46260): DataFrame.withColumnsRenamed should respect the dict ordering + @unittest.skip("Fails in Spark Connect, should enable.") + def test_ordering_of_with_columns_renamed(self): + super().test_ordering_of_with_columns_renamed() + if __name__ == "__main__": import unittest diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index 52806f4f4a382..c25fe60ad174c 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -163,6 +163,15 @@ def test_with_columns_renamed(self): message_parameters={"arg_name": "colsMap", "arg_type": "tuple"}, ) + def test_ordering_of_with_columns_renamed(self): + df = self.spark.range(10) + + df1 = df.withColumnsRenamed({"id": "a", "a": "b"}) + self.assertEqual(df1.columns, ["b"]) + + df2 = df.withColumnsRenamed({"a": "b", "id": "a"}) + self.assertEqual(df2.columns, ["a"]) + def test_drop_duplicates(self): # SPARK-36034 test that drop duplicates throws a type error when in correct type provided df = self.spark.createDataFrame([("Alice", 50), ("Alice", 60)], ["name", "age"]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 293f20c453aee..cacc193885d7f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2922,18 +2922,29 @@ class Dataset[T] private[sql]( */ @throws[AnalysisException] def withColumnsRenamed(colsMap: Map[String, String]): DataFrame = withOrigin { + val (colNames, newColNames) = colsMap.toSeq.unzip + withColumnsRenamed(colNames, newColNames) + } + + private def withColumnsRenamed( + colNames: Seq[String], + newColNames: Seq[String]): DataFrame = withOrigin { + require(colNames.size == newColNames.size, + s"The size of existing column names: ${colNames.size} isn't equal to " + + s"the size of new column names: ${newColNames.size}") + val resolver = sparkSession.sessionState.analyzer.resolver val output: Seq[NamedExpression] = queryExecution.analyzed.output - val projectList = colsMap.foldLeft(output) { + val projectList = colNames.zip(newColNames).foldLeft(output) { case (attrs, (existingName, newName)) => - attrs.map(attr => - if (resolver(attr.name, existingName)) { - Alias(attr, newName)() - } else { - attr - } - ) + attrs.map(attr => + if (resolver(attr.name, existingName)) { + Alias(attr, newName)() + } else { + attr + } + ) } SchemaUtils.checkColumnNameDuplication( projectList.map(_.name), From f9658d31d4cc42e73a0022c7e1c1441cad9aab37 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Tue, 5 Dec 2023 13:04:25 +0800 Subject: [PATCH 2/4] nit --- python/pyspark/sql/tests/connect/test_parity_dataframe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py b/python/pyspark/sql/tests/connect/test_parity_dataframe.py index 782a7ae31f2db..36f9fc48dd83b 100644 --- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py +++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py @@ -77,7 +77,7 @@ def test_to_pandas_from_mixed_dataframe(self): def test_toDF_with_string(self): super().test_toDF_with_string() - # TODO(SPARK-46260): DataFrame.withColumnsRenamed should respect the dict ordering + # TODO(SPARK-46261): DataFrame.withColumnsRenamed should respect the dict ordering @unittest.skip("Fails in Spark Connect, should enable.") def test_ordering_of_with_columns_renamed(self): super().test_ordering_of_with_columns_renamed() From e8f6594cdbefc34c78dd94e496890b13c6d72fd7 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Tue, 5 Dec 2023 13:25:37 +0800 Subject: [PATCH 3/4] scala test --- .../test/scala/org/apache/spark/sql/DataFrameSuite.scala | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index b732f6631a70c..a98006d1387fe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -987,6 +987,12 @@ class DataFrameSuite extends QueryTest parameters = Map("columnName" -> "`age`")) } + test("SPARK-46260: withColumnsRenamed should respect the Map ordering") { + val df = spark.range(10).toDF() + assert(df.withColumnsRenamed(Map("id" -> "a", "a" -> "b")).columns === Array("b")) + assert(df.withColumnsRenamed(Map("a" -> "b", "id" -> "a")).columns === Array("a")) + } + test("SPARK-20384: Value class filter") { val df = spark.sparkContext .parallelize(Seq(StringWrapper("a"), StringWrapper("b"), StringWrapper("c"))) From a9c5060bbcd037870590cc5f8b56d3335d7459b9 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Tue, 5 Dec 2023 13:39:50 +0800 Subject: [PATCH 4/4] listmap --- python/pyspark/sql/tests/connect/test_parity_dataframe.py | 2 +- .../src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py b/python/pyspark/sql/tests/connect/test_parity_dataframe.py index 36f9fc48dd83b..fbef282e0b978 100644 --- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py +++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py @@ -77,7 +77,7 @@ def test_to_pandas_from_mixed_dataframe(self): def test_toDF_with_string(self): super().test_toDF_with_string() - # TODO(SPARK-46261): DataFrame.withColumnsRenamed should respect the dict ordering + # TODO(SPARK-46261): Python Client withColumnsRenamed should respect the dict ordering @unittest.skip("Fails in Spark Connect, should enable.") def test_ordering_of_with_columns_renamed(self): super().test_ordering_of_with_columns_renamed() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index a98006d1387fe..25ecefd28cf8a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -24,6 +24,7 @@ import java.sql.{Date, Timestamp} import java.util.{Locale, UUID} import java.util.concurrent.atomic.AtomicLong +import scala.collection.immutable.ListMap import scala.reflect.runtime.universe.TypeTag import scala.util.Random @@ -989,8 +990,8 @@ class DataFrameSuite extends QueryTest test("SPARK-46260: withColumnsRenamed should respect the Map ordering") { val df = spark.range(10).toDF() - assert(df.withColumnsRenamed(Map("id" -> "a", "a" -> "b")).columns === Array("b")) - assert(df.withColumnsRenamed(Map("a" -> "b", "id" -> "a")).columns === Array("a")) + assert(df.withColumnsRenamed(ListMap("id" -> "a", "a" -> "b")).columns === Array("b")) + assert(df.withColumnsRenamed(ListMap("a" -> "b", "id" -> "a")).columns === Array("a")) } test("SPARK-20384: Value class filter") {