Skip to content

Commit

Permalink
[SPARK-47184][PYTHON][CONNECT][TESTS] Make `test_repartitionByRange_dataframe` reusable
Browse files Browse the repository at this point in the history

### What changes were proposed in this pull request?
Make `test_repartitionByRange_dataframe` reusable

### Why are the changes needed?
to make it reusable in Spark Connect

### Does this PR introduce _any_ user-facing change?
no, test-only

### How was this patch tested?
Updated the existing unit test.

### Was this patch authored or co-authored using generative AI tooling?
no

Closes apache#45281 from zhengruifeng/connect_test_repartitionByRange_dataframe.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
  • Loading branch information
zhengruifeng authored and TakawaAkirayo committed Mar 4, 2024
1 parent a817385 commit 2299d06
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 13 deletions.
4 changes: 0 additions & 4 deletions python/pyspark/sql/tests/connect/test_parity_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,6 @@ def test_observe_str(self):
def test_pandas_api(self):
super().test_pandas_api()

@unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
def test_repartitionByRange_dataframe(self):
super().test_repartitionByRange_dataframe()

@unittest.skip("Spark Connect does not support SparkContext but the tests depend on it.")
def test_same_semantics_error(self):
super().test_same_semantics_error()
Expand Down
19 changes: 10 additions & 9 deletions python/pyspark/sql/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from contextlib import redirect_stdout

from pyspark.sql import SparkSession, Row, functions
from pyspark.sql.functions import col, lit, count, sum, mean, struct
from pyspark.sql.functions import col, lit, count, sum, mean, struct, spark_partition_id
from pyspark.sql.types import (
StringType,
IntegerType,
Expand Down Expand Up @@ -483,20 +483,21 @@ def test_repartitionByRange_dataframe(self):

# test repartitionByRange(numPartitions, *cols)
df3 = df1.repartitionByRange(2, "name", "age")
self.assertEqual(df3.rdd.getNumPartitions(), 2)
self.assertEqual(df3.rdd.first(), df2.rdd.first())
self.assertEqual(df3.rdd.take(3), df2.rdd.take(3))

self.assertEqual(df3.select(spark_partition_id()).distinct().count(), 2)
self.assertEqual(df3.first(), df2.first())
self.assertEqual(df3.take(3), df2.take(3))

# test repartitionByRange(numPartitions, *cols)
df4 = df1.repartitionByRange(3, "name", "age")
self.assertEqual(df4.rdd.getNumPartitions(), 3)
self.assertEqual(df4.rdd.first(), df2.rdd.first())
self.assertEqual(df4.rdd.take(3), df2.rdd.take(3))
self.assertEqual(df4.select(spark_partition_id()).distinct().count(), 3)
self.assertEqual(df4.first(), df2.first())
self.assertEqual(df4.take(3), df2.take(3))

# test repartitionByRange(*cols)
df5 = df1.repartitionByRange(5, "name", "age")
self.assertEqual(df5.rdd.first(), df2.rdd.first())
self.assertEqual(df5.rdd.take(3), df2.rdd.take(3))
self.assertEqual(df5.first(), df2.first())
self.assertEqual(df5.take(3), df2.take(3))

with self.assertRaises(PySparkTypeError) as pe:
df1.repartitionByRange([10], "name", "age")
Expand Down

0 comments on commit 2299d06

Please sign in to comment.