
Commit

Address comments
maddiedawson committed May 22, 2023
1 parent ccf5ba8 commit 347580f
Showing 3 changed files with 5 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/datasets/arrow_dataset.py
@@ -1248,7 +1248,7 @@ def from_spark(
     Whether to copy the data in-memory.
 working_dir (`str`, *optional*)
     Intermediate directory for each Spark worker to write data to before moving it to `cache_dir`. Setting
-    a non-NFS intermediate directory may improve performance.
+    a non-NFS intermediate directory may improve performance. Can also be set via env var HF_WORKING_DIR.
 load_from_cache_file (`bool`):
     Whether to load the dataset from the cache if possible.
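For context, a minimal sketch of how the documented `working_dir` argument might be passed to `Dataset.from_spark`; the DataFrame and the paths below are hypothetical, and a running SparkSession is assumed:

# Minimal sketch (hypothetical DataFrame and paths); assumes pyspark and datasets are installed.
from pyspark.sql import SparkSession
from datasets import Dataset

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "text"])

# Pointing working_dir at fast, non-NFS local storage lets each Spark worker
# write its shards locally before they are moved into cache_dir.
ds = Dataset.from_spark(
    df,
    cache_dir="/dbfs/spark_cache",   # hypothetical shared cache directory
    working_dir="/local_disk0/tmp",  # hypothetical local scratch directory
)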
8 changes: 2 additions & 6 deletions src/datasets/packaged_modules/spark/spark.py
@@ -91,7 +91,6 @@ def __init__(

         self._spark = pyspark.sql.SparkSession.builder.getOrCreate()
         self.df = df
-        self._validate_cache_dir(cache_dir)
         self._working_dir = working_dir

         super().__init__(
@@ -136,8 +135,6 @@ def _split_generators(self, dl_manager: datasets.download.download_manager.Downl
         return [datasets.SplitGenerator(name=datasets.Split.TRAIN)]

     def _repartition_df_if_needed(self, max_shard_size):
-        import pyspark
-
         def get_arrow_batch_size(it):
             for batch in it:
                 yield pa.RecordBatch.from_pydict({"batch_bytes": [batch.nbytes]})
@@ -166,10 +163,9 @@ def _prepare_split_single(
         file_format: str,
         max_shard_size: int,
     ) -> Iterable[Tuple[int, bool, Union[int, tuple]]]:
-        import pyspark
-
         writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter
-        working_fpath = os.path.join(self._working_dir, os.path.basename(fpath)) if self._working_dir else fpath
+        working_dir = self._working_dir or os.environ.get("HF_WORKING_DIR")
+        working_fpath = os.path.join(working_dir, os.path.basename(fpath)) if working_dir else fpath
         embed_local_files = file_format == "parquet"

         # Define these so that we don't reference self in write_arrow, which will result in a pickling error due to
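The new fallback means the intermediate directory can also come from the environment instead of the call site. A small self-contained sketch of that precedence (the helper name and the path are hypothetical, written only to mirror the line added above):

import os

# Hypothetical helper mirroring the added line: the explicit argument wins,
# and the HF_WORKING_DIR env var is only a fallback.
def resolve_working_dir(explicit_working_dir=None):
    return explicit_working_dir or os.environ.get("HF_WORKING_DIR")

os.environ["HF_WORKING_DIR"] = "/local_disk0/tmp"   # hypothetical local scratch path
print(resolve_working_dir())                # -> /local_disk0/tmp (env var fallback)
print(resolve_working_dir("/mnt/scratch"))  # -> /mnt/scratch (explicit argument wins)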
2 changes: 2 additions & 0 deletions tests/packaged_modules/test_spark.py
@@ -22,6 +22,7 @@ def _get_expected_row_ids_and_row_dicts_for_partition_order(df, partition_order)
             expected_row_ids_and_row_dicts.append((f"{part_id}_{row_idx}", row.asDict()))
     return expected_row_ids_and_row_dicts

+
 @require_not_windows
 @require_dill_gt_0_3_2
 def test_repartition_df_if_needed():
@@ -104,6 +105,7 @@ def test_spark_examples_iterable_shard():
         assert row_id == expected_row_id
         assert row_dict == expected_row_dict

+
 def test_repartition_df_if_needed_max_num_df_rows():
     spark = pyspark.sql.SparkSession.builder.master("local[*]").appName("pyspark").getOrCreate()
     df = spark.range(100).repartition(1)
