Skip to content

Commit

Permalink
[SPARK-46260][PYTHON][SQL] `DataFrame.withColumnsRenamed` should respe…
Browse files Browse the repository at this point in the history
…ct the dict ordering

### What changes were proposed in this pull request?
Make `DataFrame.withColumnsRenamed` respect the dict ordering

### Why are the changes needed?
The order of the renames passed to `withColumnsRenamed` matters, because they are applied one after another.

For example, in Scala:
```
scala> val df = spark.range(1000)
val df: org.apache.spark.sql.Dataset[Long] = [id: bigint]

scala> df.withColumnsRenamed(Map("id" -> "a", "a" -> "b"))
val res0: org.apache.spark.sql.DataFrame = [b: bigint]

scala> df.withColumnsRenamed(Map("a" -> "b", "id" -> "a"))
val res1: org.apache.spark.sql.DataFrame = [a: bigint]
```

However, in py4j the Python `dict` -> JVM `map` conversion cannot guarantee the ordering, so the renames may be applied in a different order than the one written in the `dict`.

### Does this PR introduce _any_ user-facing change?
Yes, this is a behavior change: PySpark's `DataFrame.withColumnsRenamed` now applies the renames in `dict` order.

before this PR
```
In [1]: df = spark.range(10)

In [2]: df.withColumnsRenamed({"id": "a", "a": "b"})
Out[2]: DataFrame[a: bigint]

In [3]: df.withColumnsRenamed({"a": "b", "id": "a"})
Out[3]: DataFrame[a: bigint]
```

after this PR
```
In [1]: df = spark.range(10)

In [2]: df.withColumnsRenamed({"id": "a", "a": "b"})
Out[2]: DataFrame[b: bigint]

In [3]: df.withColumnsRenamed({"a": "b", "id": "a"})
Out[3]: DataFrame[a: bigint]
```

### How was this patch tested?
Added unit tests.

### Was this patch authored or co-authored using generative AI tooling?
no

Closes #44177 from zhengruifeng/sql_withColumnsRenamed_sql.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
  • Loading branch information
zhengruifeng authored and HyukjinKwon committed Dec 6, 2023
1 parent 35a99a8 commit 032e782
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 9 deletions.
13 changes: 12 additions & 1 deletion python/pyspark/sql/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6272,7 +6272,18 @@ def withColumnsRenamed(self, colsMap: Dict[str, str]) -> "DataFrame":
message_parameters={"arg_name": "colsMap", "arg_type": type(colsMap).__name__},
)

return DataFrame(self._jdf.withColumnsRenamed(colsMap), self.sparkSession)
col_names: List[str] = []
new_col_names: List[str] = []
for k, v in colsMap.items():
col_names.append(k)
new_col_names.append(v)

return DataFrame(
self._jdf.withColumnsRenamed(
_to_seq(self._sc, col_names), _to_seq(self._sc, new_col_names)
),
self.sparkSession,
)

def withMetadata(self, columnName: str, metadata: Dict[str, Any]) -> "DataFrame":
"""Returns a new :class:`DataFrame` by updating an existing column with metadata.
Expand Down
5 changes: 5 additions & 0 deletions python/pyspark/sql/tests/connect/test_parity_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ def test_to_pandas_from_mixed_dataframe(self):
def test_toDF_with_string(self):
super().test_toDF_with_string()

# TODO(SPARK-46261): Python Client withColumnsRenamed should respect the dict ordering
# The inherited ordering test is skipped here because it currently fails under
# Spark Connect; remove the skip once SPARK-46261 is resolved.
@unittest.skip("Fails in Spark Connect, should enable.")
def test_ordering_of_with_columns_renamed(self):
super().test_ordering_of_with_columns_renamed()


if __name__ == "__main__":
import unittest
Expand Down
9 changes: 9 additions & 0 deletions python/pyspark/sql/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,15 @@ def test_with_columns_renamed(self):
message_parameters={"arg_name": "colsMap", "arg_type": "tuple"},
)

def test_ordering_of_with_columns_renamed(self):
    """Check that ``withColumnsRenamed`` applies renames in dict insertion order."""
    base = self.spark.range(10)

    cases = [
        # "id" -> "a" runs first, so the follow-up "a" -> "b" rename applies too.
        ({"id": "a", "a": "b"}, ["b"]),
        # "a" -> "b" runs first and matches nothing; only "id" -> "a" takes effect.
        ({"a": "b", "id": "a"}, ["a"]),
    ]
    for cols_map, expected_columns in cases:
        renamed = base.withColumnsRenamed(cols_map)
        self.assertEqual(renamed.columns, expected_columns)

def test_drop_duplicates(self):
# SPARK-36034 test that drop duplicates throws a type error when in correct type provided
df = self.spark.createDataFrame([("Alice", 50), ("Alice", 60)], ["name", "age"])
Expand Down
27 changes: 19 additions & 8 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2922,18 +2922,29 @@ class Dataset[T] private[sql](
*/
@throws[AnalysisException]
def withColumnsRenamed(colsMap: Map[String, String]): DataFrame = withOrigin {
  // Materialize the entries once, in the map's own iteration order, then pass the
  // existing and new names as two parallel sequences to the Seq-based overload.
  val renames = colsMap.toSeq
  val existingNames = renames.map(_._1)
  val replacementNames = renames.map(_._2)
  withColumnsRenamed(existingNames, replacementNames)
}

private def withColumnsRenamed(
colNames: Seq[String],
newColNames: Seq[String]): DataFrame = withOrigin {
require(colNames.size == newColNames.size,
s"The size of existing column names: ${colNames.size} isn't equal to " +
s"the size of new column names: ${newColNames.size}")

val resolver = sparkSession.sessionState.analyzer.resolver
val output: Seq[NamedExpression] = queryExecution.analyzed.output

val projectList = colsMap.foldLeft(output) {
val projectList = colNames.zip(newColNames).foldLeft(output) {
case (attrs, (existingName, newName)) =>
attrs.map(attr =>
if (resolver(attr.name, existingName)) {
Alias(attr, newName)()
} else {
attr
}
)
attrs.map(attr =>
if (resolver(attr.name, existingName)) {
Alias(attr, newName)()
} else {
attr
}
)
}
SchemaUtils.checkColumnNameDuplication(
projectList.map(_.name),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import java.sql.{Date, Timestamp}
import java.util.{Locale, UUID}
import java.util.concurrent.atomic.AtomicLong

import scala.collection.immutable.ListMap
import scala.reflect.runtime.universe.TypeTag
import scala.util.Random

Expand Down Expand Up @@ -987,6 +988,12 @@ class DataFrameSuite extends QueryTest
parameters = Map("columnName" -> "`age`"))
}

test("SPARK-46260: withColumnsRenamed should respect the Map ordering") {
  val df = spark.range(10).toDF()

  // ListMap iterates in insertion order, so each case pins the order of the renames.
  // "id" -> "a" comes first, so the subsequent "a" -> "b" rename applies as well.
  val chained = df.withColumnsRenamed(ListMap("id" -> "a", "a" -> "b"))
  assert(chained.columns === Array("b"))

  // "a" -> "b" comes first and matches no column; only "id" -> "a" takes effect.
  val single = df.withColumnsRenamed(ListMap("a" -> "b", "id" -> "a"))
  assert(single.columns === Array("a"))
}

test("SPARK-20384: Value class filter") {
val df = spark.sparkContext
.parallelize(Seq(StringWrapper("a"), StringWrapper("b"), StringWrapper("c")))
Expand Down

0 comments on commit 032e782

Please sign in to comment.