Fix sampling and add unit tests
maddiedawson committed May 22, 2023
1 parent c2f28db commit ccf5ba8
Showing 2 changed files with 29 additions and 4 deletions.
11 changes: 7 additions & 4 deletions src/datasets/packaged_modules/spark/spark.py
@@ -143,19 +143,22 @@ def get_arrow_batch_size(it):
                 yield pa.RecordBatch.from_pydict({"batch_bytes": [batch.nbytes]})
 
         df_num_rows = self.df.count()
-        # Approximate the size of each row (in Arrow format) by averaging over a 100-row sample.
+        sample_num_rows = df_num_rows if df_num_rows <= 100 else 100
+        # Approximate the size of each row (in Arrow format) by averaging over a max-100-row sample.
         approx_bytes_per_row = (
-            self.df.limit(100)
+            self.df.limit(sample_num_rows)
             .repartition(1)
             .mapInArrow(get_arrow_batch_size, "batch_bytes: long")
             .agg(pyspark.sql.functions.sum("batch_bytes").alias("sample_bytes"))
             .collect()[0]
             .sample_bytes
-            / 100
+            / sample_num_rows
         )
         approx_total_size = approx_bytes_per_row * df_num_rows
         if approx_total_size > max_shard_size:
-            self.df = self.df.repartition(int(approx_total_size / max_shard_size))
+            # Make sure there is at least one row per partition.
+            new_num_partitions = min(df_num_rows, int(approx_total_size / max_shard_size))
+            self.df = self.df.repartition(new_num_partitions)
 
     def _prepare_split_single(
         self,
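The change boils down to two things: the per-row size estimate now divides by the actual sample size rather than a hard-coded 100, and the new partition count is capped at the row count so no partition ends up empty. A minimal standalone sketch of that arithmetic, using the numbers from the tests below (the helper name and the return-0-when-no-repartition convention are illustrative, not part of the library):

# Standalone sketch of the arithmetic in _repartition_df_if_needed; no Spark required.
def estimate_new_num_partitions(df_num_rows, sample_bytes, sample_num_rows, max_shard_size):
    """Return the partition count the builder would pick, or 0 if no repartition is needed."""
    approx_bytes_per_row = sample_bytes / sample_num_rows
    approx_total_size = approx_bytes_per_row * df_num_rows
    if approx_total_size <= max_shard_size:
        return 0  # existing partitioning is left alone
    # Cap at df_num_rows so every partition holds at least one row.
    return min(df_num_rows, int(approx_total_size / max_shard_size))

# 100 int64 ids -> 8 bytes per row -> 800 bytes total; max_shard_size=16 gives 50 partitions.
assert estimate_new_num_partitions(100, sample_bytes=800, sample_num_rows=100, max_shard_size=16) == 50
# With max_shard_size=1 the raw ratio is 800, but the cap keeps it at 100 (one row per partition).
assert estimate_new_num_partitions(100, sample_bytes=800, sample_num_rows=100, max_shard_size=1) == 100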
22 changes: 22 additions & 0 deletions tests/packaged_modules/test_spark.py
@@ -3,6 +3,7 @@
 import pyspark
 
 from datasets.packaged_modules.spark.spark import (
+    Spark,
     SparkExamplesIterable,
     _generate_iterable_examples,
 )
@@ -21,6 +22,18 @@ def _get_expected_row_ids_and_row_dicts_for_partition_order(df, partition_order)
             expected_row_ids_and_row_dicts.append((f"{part_id}_{row_idx}", row.asDict()))
     return expected_row_ids_and_row_dicts
 
+@require_not_windows
+@require_dill_gt_0_3_2
+def test_repartition_df_if_needed():
+    spark = pyspark.sql.SparkSession.builder.master("local[*]").appName("pyspark").getOrCreate()
+    df = spark.range(100).repartition(1)
+    spark_builder = Spark(df)
+    # The id ints will be converted to Pyarrow int64s, so each row will be 8 bytes. Setting a max_shard_size of 16 means
+    # that each partition can hold 2 rows.
+    spark_builder._repartition_df_if_needed(max_shard_size=16)
+    # Given that the dataframe has 100 rows and each partition has 2 rows, we expect 50 partitions.
+    assert spark_builder.df.rdd.getNumPartitions() == 50
+
 
 @require_not_windows
 @require_dill_gt_0_3_2
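The test above leans on the fact that a Spark `id` column arrives as Arrow int64 data, i.e. 8 bytes per row. A quick way to sanity-check that figure outside of Spark (this snippet is ours, not part of the test suite, and assumes pyarrow allocates no validity bitmap for a null-free column):

import pyarrow as pa

# Two non-null int64 ids occupy one 8-byte slot each in Arrow's columnar buffer.
batch = pa.RecordBatch.from_pydict({"id": pa.array([0, 1], type=pa.int64())})
print(batch.nbytes)  # expected: 16, i.e. 8 bytes per row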
@@ -90,3 +103,12 @@ def test_spark_examples_iterable_shard():
         expected_row_id, expected_row_dict = expected_row_ids_and_row_dicts_2[i]
         assert row_id == expected_row_id
         assert row_dict == expected_row_dict
+
+def test_repartition_df_if_needed_max_num_df_rows():
+    spark = pyspark.sql.SparkSession.builder.master("local[*]").appName("pyspark").getOrCreate()
+    df = spark.range(100).repartition(1)
+    spark_builder = Spark(df)
+    # Choose a small max_shard_size for maximum partitioning.
+    spark_builder._repartition_df_if_needed(max_shard_size=1)
+    # The new number of partitions should not be greater than the number of rows.
+    assert spark_builder.df.rdd.getNumPartitions() == 100
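To reproduce locally, the two new tests can be run from the repository root with pytest. This invocation is an example, not part of the commit; it assumes a working local Spark, pyspark, dill > 0.3.2, and a non-Windows machine, since the tests are otherwise skipped:

# Example local run of just the new repartitioning tests.
import pytest

pytest.main(["tests/packaged_modules/test_spark.py", "-k", "repartition_df_if_needed", "-q"])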
