diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 5211d874ba33e..1419d1f3cb635 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -6272,7 +6272,18 @@ def withColumnsRenamed(self, colsMap: Dict[str, str]) -> "DataFrame":
                 message_parameters={"arg_name": "colsMap", "arg_type": type(colsMap).__name__},
             )
 
-        return DataFrame(self._jdf.withColumnsRenamed(colsMap), self.sparkSession)
+        col_names: List[str] = []
+        new_col_names: List[str] = []
+        for k, v in colsMap.items():
+            col_names.append(k)
+            new_col_names.append(v)
+
+        return DataFrame(
+            self._jdf.withColumnsRenamed(
+                _to_seq(self._sc, col_names), _to_seq(self._sc, new_col_names)
+            ),
+            self.sparkSession,
+        )
 
     def withMetadata(self, columnName: str, metadata: Dict[str, Any]) -> "DataFrame":
         """Returns a new :class:`DataFrame` by updating an existing column with metadata.
diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
index b7b4fdcd287b3..fbef282e0b978 100644
--- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py
+++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
@@ -77,6 +77,11 @@ def test_to_pandas_from_mixed_dataframe(self):
     def test_toDF_with_string(self):
         super().test_toDF_with_string()
 
+    # TODO(SPARK-46261): Python Client withColumnsRenamed should respect the dict ordering
+    @unittest.skip("Fails in Spark Connect, should enable.")
+    def test_ordering_of_with_columns_renamed(self):
+        super().test_ordering_of_with_columns_renamed()
+
 
 if __name__ == "__main__":
     import unittest
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 52806f4f4a382..c25fe60ad174c 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -163,6 +163,15 @@ def test_with_columns_renamed(self):
             message_parameters={"arg_name": "colsMap", "arg_type": "tuple"},
         )
 
+    def test_ordering_of_with_columns_renamed(self):
+        df = self.spark.range(10)
+
+        df1 = df.withColumnsRenamed({"id": "a", "a": "b"})
+        self.assertEqual(df1.columns, ["b"])
+
+        df2 = df.withColumnsRenamed({"a": "b", "id": "a"})
+        self.assertEqual(df2.columns, ["a"])
+
     def test_drop_duplicates(self):
         # SPARK-36034 test that drop duplicates throws a type error when in correct type provided
         df = self.spark.createDataFrame([("Alice", 50), ("Alice", 60)], ["name", "age"])
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 293f20c453aee..cacc193885d7f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2922,18 +2922,29 @@ class Dataset[T] private[sql](
    */
   @throws[AnalysisException]
   def withColumnsRenamed(colsMap: Map[String, String]): DataFrame = withOrigin {
+    val (colNames, newColNames) = colsMap.toSeq.unzip
+    withColumnsRenamed(colNames, newColNames)
+  }
+
+  private def withColumnsRenamed(
+      colNames: Seq[String],
+      newColNames: Seq[String]): DataFrame = withOrigin {
+    require(colNames.size == newColNames.size,
+      s"The size of existing column names: ${colNames.size} isn't equal to " +
+        s"the size of new column names: ${newColNames.size}")
+
     val resolver = sparkSession.sessionState.analyzer.resolver
     val output: Seq[NamedExpression] = queryExecution.analyzed.output
 
-    val projectList = colsMap.foldLeft(output) {
+    val projectList = colNames.zip(newColNames).foldLeft(output) {
       case (attrs, (existingName, newName)) =>
-        attrs.map(attr =>
-          if (resolver(attr.name, existingName)) {
-            Alias(attr, newName)()
-          } else {
-            attr
-          }
-        )
+        attrs.map(attr =>
+          if (resolver(attr.name, existingName)) {
+            Alias(attr, newName)()
+          } else {
+            attr
+          }
+        )
     }
     SchemaUtils.checkColumnNameDuplication(
       projectList.map(_.name),
       sparkSession.sessionState.conf.caseSensitiveAnalysis)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index b732f6631a70c..25ecefd28cf8a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -24,6 +24,7 @@ import java.sql.{Date, Timestamp}
 import java.util.{Locale, UUID}
 import java.util.concurrent.atomic.AtomicLong
 
+import scala.collection.immutable.ListMap
 import scala.reflect.runtime.universe.TypeTag
 import scala.util.Random
 
@@ -987,6 +988,12 @@ class DataFrameSuite extends QueryTest
       parameters = Map("columnName" -> "`age`"))
   }
 
+  test("SPARK-46260: withColumnsRenamed should respect the Map ordering") {
+    val df = spark.range(10).toDF()
+    assert(df.withColumnsRenamed(ListMap("id" -> "a", "a" -> "b")).columns === Array("b"))
+    assert(df.withColumnsRenamed(ListMap("a" -> "b", "id" -> "a")).columns === Array("a"))
+  }
+
   test("SPARK-20384: Value class filter") {
     val df = spark.sparkContext
       .parallelize(Seq(StringWrapper("a"), StringWrapper("b"), StringWrapper("c")))
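Reviewer note: the sequential semantics that `test_ordering_of_with_columns_renamed` and the `ListMap` suite assert can be reproduced without a Spark session. The sketch below is plain Scala; `RenameOrderingSketch`, `resolver`, and `renameAll` are illustrative names rather than Spark APIs, and the fold mirrors the patched `colNames.zip(newColNames).foldLeft(output)` in `Dataset.withColumnsRenamed` under the assumption of a case-insensitive resolver (Spark's default).

```scala
// Minimal, Spark-free sketch of why withColumnsRenamed must respect Map ordering.
// All names here are hypothetical stand-ins, not Spark APIs.
object RenameOrderingSketch {
  // Case-insensitive name matching, mirroring the analyzer's default resolver.
  private def resolver(a: String, b: String): Boolean = a.equalsIgnoreCase(b)

  // Apply each (existingName -> newName) pair in sequence over the column list,
  // like the foldLeft over colNames.zip(newColNames) in the patch. A later pair
  // can therefore match a column produced by an earlier rename.
  def renameAll(columns: Seq[String], renames: Seq[(String, String)]): Seq[String] =
    renames.foldLeft(columns) { case (cols, (existingName, newName)) =>
      cols.map(c => if (resolver(c, existingName)) newName else c)
    }

  def main(args: Array[String]): Unit = {
    // "id" -> "a" fires first, then "a" -> "b" matches the freshly renamed column.
    assert(renameAll(Seq("id"), Seq("id" -> "a", "a" -> "b")) == Seq("b"))
    // Reversed insertion order: "a" -> "b" matches nothing, then "id" -> "a" applies.
    assert(renameAll(Seq("id"), Seq("a" -> "b", "id" -> "a")) == Seq("a"))
    println("ordering semantics hold")
  }
}
```

This is also why the Python side switches from passing the dict straight to the JVM to passing two aligned sequences: it keeps the insertion order of the dict observable on the Scala side, where a plain `Map` would not guarantee it.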