From 347580f2a4439aae8c898b111f9b6153fd9d5f3e Mon Sep 17 00:00:00 2001
From: Maddie Dawson
Date: Mon, 22 May 2023 13:11:28 -0700
Subject: [PATCH] Address comments

---
 src/datasets/arrow_dataset.py                | 2 +-
 src/datasets/packaged_modules/spark/spark.py | 8 ++------
 tests/packaged_modules/test_spark.py         | 2 ++
 3 files changed, 5 insertions(+), 7 deletions(-)

diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index 823d543fa574..054f752526b5 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -1248,7 +1248,7 @@ def from_spark(
                 Whether to copy the data in-memory.
             working_dir (`str`, *optional*)
                 Intermediate directory for each Spark worker to write data to before moving it to `cache_dir`. Setting
-                a non-NFS intermediate directory may improve performance.
+                a non-NFS intermediate directory may improve performance. Can also be set via env var HF_WORKING_DIR.
             load_from_cache_file (`bool`):
                 Whether to load the dataset from the cache if possible.
diff --git a/src/datasets/packaged_modules/spark/spark.py b/src/datasets/packaged_modules/spark/spark.py
index de516e126bb7..10e3d373ba3e 100644
--- a/src/datasets/packaged_modules/spark/spark.py
+++ b/src/datasets/packaged_modules/spark/spark.py
@@ -91,7 +91,6 @@ def __init__(
         self._spark = pyspark.sql.SparkSession.builder.getOrCreate()
         self.df = df
-        self._validate_cache_dir(cache_dir)
         self._working_dir = working_dir
         super().__init__(
@@ -136,8 +135,6 @@ def _split_generators(self, dl_manager: datasets.download.download_manager.Downl
         return [datasets.SplitGenerator(name=datasets.Split.TRAIN)]
 
     def _repartition_df_if_needed(self, max_shard_size):
-        import pyspark
-
         def get_arrow_batch_size(it):
             for batch in it:
                 yield pa.RecordBatch.from_pydict({"batch_bytes": [batch.nbytes]})
@@ -166,10 +163,9 @@ def _prepare_split_single(
         self,
         fpath: str,
         file_format: str,
         max_shard_size: int,
     ) -> Iterable[Tuple[int, bool, Union[int, tuple]]]:
-        import pyspark
-
         writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter
-        working_fpath = os.path.join(self._working_dir, os.path.basename(fpath)) if self._working_dir else fpath
+        working_dir = self._working_dir or os.environ.get("HF_WORKING_DIR")
+        working_fpath = os.path.join(working_dir, os.path.basename(fpath)) if working_dir else fpath
         embed_local_files = file_format == "parquet"
 
         # Define these so that we don't reference self in write_arrow, which will result in a pickling error due to
diff --git a/tests/packaged_modules/test_spark.py b/tests/packaged_modules/test_spark.py
index 3b1c32c91d19..610c5bb53f7f 100644
--- a/tests/packaged_modules/test_spark.py
+++ b/tests/packaged_modules/test_spark.py
@@ -22,6 +22,7 @@ def _get_expected_row_ids_and_row_dicts_for_partition_order(df, partition_order)
         expected_row_ids_and_row_dicts.append((f"{part_id}_{row_idx}", row.asDict()))
     return expected_row_ids_and_row_dicts
 
+
 @require_not_windows
 @require_dill_gt_0_3_2
 def test_repartition_df_if_needed():
@@ -104,6 +105,7 @@ def test_spark_examples_iterable_shard():
         assert row_id == expected_row_id
         assert row_dict == expected_row_dict
 
+
 def test_repartition_df_if_needed_max_num_df_rows():
     spark = pyspark.sql.SparkSession.builder.master("local[*]").appName("pyspark").getOrCreate()
     df = spark.range(100).repartition(1)
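
Usage sketch of the behavior this patch adds: when `working_dir` is not passed to `Dataset.from_spark`, the HF_WORKING_DIR environment variable is used as a fallback for the Spark workers' intermediate directory. This is a minimal, non-authoritative example; the local SparkSession setup and the /tmp paths are placeholders, not part of the patch.

# Assumes pyspark and datasets (with this patch) are installed.
import os

import pyspark
from datasets import Dataset

# Fallback picked up in _prepare_split_single when working_dir is not given.
os.environ["HF_WORKING_DIR"] = "/tmp/hf_spark_scratch"  # placeholder path

spark = pyspark.sql.SparkSession.builder.master("local[*]").getOrCreate()
df = spark.range(100)

# Equivalent to passing working_dir="/tmp/hf_spark_scratch" explicitly.
ds = Dataset.from_spark(df, cache_dir="/tmp/hf_cache")  # placeholder cache dir
print(len(ds))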