Skip to content

Commit

Permalink
Use default cast for sliced list arrays if pyarrow >= 4 (#2497)
Browse files — browse the repository at this point in the history
* Set PYARROW_VERSION in config

* Use default cast for sliced list arrays if pyarrow >= 4

* Fix style
  • Branch information
albertvillanova authored Jun 14, 2021
1 parent 7aedad6 commit d6d0ede
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 7 deletions.
6 changes: 3 additions & 3 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@

logger = get_logger(__name__)

if int(pa.__version__.split(".")[0]) == 0:
if int(config.PYARROW_VERSION.split(".")[0]) == 0:
PYARROW_V0 = True
else:
PYARROW_V0 = False
Expand Down Expand Up @@ -935,7 +935,7 @@ def cast_(
schema = pa.schema({col_name: type[col_name].type for col_name in self._data.column_names})
dataset = self.with_format("arrow")
dataset = dataset.map(
lambda t: cast_with_sliced_list_support(t, schema),
lambda t: t.cast(schema) if config.PYARROW_VERSION >= "4" else cast_with_sliced_list_support(t, schema),
batched=True,
batch_size=batch_size,
keep_in_memory=keep_in_memory,
Expand Down Expand Up @@ -994,7 +994,7 @@ def cast(
format = self.format
dataset = self.with_format("arrow")
dataset = dataset.map(
lambda t: cast_with_sliced_list_support(t, schema),
lambda t: t.cast(schema) if config.PYARROW_VERSION >= "4" else cast_with_sliced_list_support(t, schema),
batched=True,
batch_size=batch_size,
keep_in_memory=keep_in_memory,
Expand Down
2 changes: 2 additions & 0 deletions src/datasets/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@


# Imports
PYARROW_VERSION = importlib_metadata.version("pyarrow")

USE_TF = os.environ.get("USE_TF", "AUTO").upper()
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()

Expand Down
3 changes: 1 addition & 2 deletions src/datasets/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@

import numpy as np
import posixpath
import pyarrow as pa
import requests
from tqdm.auto import tqdm

Expand Down Expand Up @@ -371,7 +370,7 @@ def cached_path(

def get_datasets_user_agent(user_agent: Optional[Union[str, dict]] = None) -> str:
ua = "datasets/{}; python/{}".format(__version__, config.PY_VERSION)
ua += "; pyarrow/{}".format(pa.__version__)
ua += "; pyarrow/{}".format(config.PYARROW_VERSION)
if config.TORCH_AVAILABLE:
ua += "; torch/{}".format(config.TORCH_VERSION)
if config.TF_AVAILABLE:
Expand Down
3 changes: 2 additions & 1 deletion tests/test_arrow_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest
from packaging import version

from datasets import config
from datasets.arrow_writer import ArrowWriter, OptimizedTypedSequence, TypedSequence
from datasets.features import Array2DExtensionType
from datasets.keyhash import DuplicatedKeysError, InvalidKeyError
Expand Down Expand Up @@ -57,7 +58,7 @@ def test_try_incompatible_extension_type(self):
self.assertEqual(arr.type, pa.string())

def test_catch_overflow(self):
if version.parse(pa.__version__) < version.parse("2.0.0"):
if version.parse(config.PYARROW_VERSION) < version.parse("2.0.0"):
with self.assertRaises(OverflowError):
_ = pa.array(TypedSequence([["x" * 1024]] * ((2 << 20) + 1))) # ListArray with a bit more than 2GB

Expand Down
3 changes: 2 additions & 1 deletion tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pyarrow as pa
import pytest

from datasets import config
from datasets.table import (
ConcatenationTable,
InMemoryTable,
Expand Down Expand Up @@ -745,7 +746,7 @@ def test_concatenation_table_cast(
for k, v in zip(in_memory_pa_table.schema.names, in_memory_pa_table.schema.types)
}
)
if pa.__version__ < "4":
if config.PYARROW_VERSION < "4":
with pytest.raises(pa.ArrowNotImplementedError):
ConcatenationTable.from_blocks(blocks).cast(schema)
else:
Expand Down

1 comment on commit d6d0ede

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Show benchmarks

PyArrow==1.0.0

Show updated benchmarks!

Benchmark: benchmark_array_xd.json

metric read_batch_formatted_as_numpy after write_array2d read_batch_formatted_as_numpy after write_flattened_sequence read_batch_formatted_as_numpy after write_nested_sequence read_batch_unformated after write_array2d read_batch_unformated after write_flattened_sequence read_batch_unformated after write_nested_sequence read_col_formatted_as_numpy after write_array2d read_col_formatted_as_numpy after write_flattened_sequence read_col_formatted_as_numpy after write_nested_sequence read_col_unformated after write_array2d read_col_unformated after write_flattened_sequence read_col_unformated after write_nested_sequence read_formatted_as_numpy after write_array2d read_formatted_as_numpy after write_flattened_sequence read_formatted_as_numpy after write_nested_sequence read_unformated after write_array2d read_unformated after write_flattened_sequence read_unformated after write_nested_sequence write_array2d write_flattened_sequence write_nested_sequence
new / old (diff) 0.020471 / 0.011353 (0.009118) 0.013975 / 0.011008 (0.002967) 0.044859 / 0.038508 (0.006351) 0.036489 / 0.023109 (0.013380) 0.326674 / 0.275898 (0.050776) 0.354407 / 0.323480 (0.030927) 0.010586 / 0.007986 (0.002600) 0.004959 / 0.004328 (0.000630) 0.010748 / 0.004250 (0.006498) 0.049859 / 0.037052 (0.012807) 0.319081 / 0.258489 (0.060592) 0.367117 / 0.293841 (0.073276) 0.137019 / 0.128546 (0.008473) 0.101542 / 0.075646 (0.025895) 0.383885 / 0.419271 (-0.035387) 0.634647 / 0.043533 (0.591114) 0.312707 / 0.255139 (0.057568) 0.344000 / 0.283200 (0.060800) 1.993315 / 0.141683 (1.851633) 1.584215 / 1.452155 (0.132060) 1.669910 / 1.492716 (0.177193)

Benchmark: benchmark_getitem_100B.json

metric get_batch_of_1024_random_rows get_batch_of_1024_rows get_first_row get_last_row
new / old (diff) 0.013082 / 0.018006 (-0.004924) 0.510007 / 0.000490 (0.509517) 0.001367 / 0.000200 (0.001167) 0.000057 / 0.000054 (0.000003)

Benchmark: benchmark_indices_mapping.json

metric select shard shuffle sort train_test_split
new / old (diff) 0.040025 / 0.037411 (0.002613) 0.024405 / 0.014526 (0.009879) 0.028862 / 0.176557 (-0.147695) 0.043334 / 0.737135 (-0.693802) 0.028959 / 0.296338 (-0.267379)

Benchmark: benchmark_iterating.json

metric read 5000 read 50000 read_batch 50000 10 read_batch 50000 100 read_batch 50000 1000 read_formatted numpy 5000 read_formatted pandas 5000 read_formatted tensorflow 5000 read_formatted torch 5000 read_formatted_batch numpy 5000 10 read_formatted_batch numpy 5000 1000 shuffled read 5000 shuffled read 50000 shuffled read_batch 50000 10 shuffled read_batch 50000 100 shuffled read_batch 50000 1000 shuffled read_formatted numpy 5000 shuffled read_formatted_batch numpy 5000 10 shuffled read_formatted_batch numpy 5000 1000
new / old (diff) 0.369443 / 0.215209 (0.154234) 3.638694 / 2.077655 (1.561040) 1.886351 / 1.504120 (0.382231) 1.661633 / 1.541195 (0.120438) 1.687016 / 1.468490 (0.218526) 5.300121 / 4.584777 (0.715344) 4.773553 / 3.745712 (1.027841) 7.152259 / 5.269862 (1.882397) 5.389717 / 4.565676 (0.824041) 0.528704 / 0.424275 (0.104429) 0.009330 / 0.007607 (0.001723) 0.458836 / 0.226044 (0.232792) 4.673920 / 2.268929 (2.404991) 2.331055 / 55.444624 (-53.113569) 1.953174 / 6.876477 (-4.923303) 2.001134 / 2.142072 (-0.140938) 5.429407 / 4.805227 (0.624179) 4.531457 / 6.500664 (-1.969207) 8.383503 / 0.075469 (8.308034)

Benchmark: benchmark_map_filter.json

metric filter map fast-tokenizer batched map identity map identity batched map no-op batched map no-op batched numpy map no-op batched pandas map no-op batched pytorch map no-op batched tensorflow
new / old (diff) 11.155326 / 1.841788 (9.313538) 12.953267 / 8.074308 (4.878959) 26.080907 / 10.191392 (15.889515) 0.741147 / 0.680424 (0.060723) 0.508183 / 0.534201 (-0.026018) 0.671835 / 0.579283 (0.092552) 0.489696 / 0.434364 (0.055332) 0.570656 / 0.540337 (0.030319) 1.306615 / 1.386936 (-0.080321)
PyArrow==latest
Show updated benchmarks!

Benchmark: benchmark_array_xd.json

metric read_batch_formatted_as_numpy after write_array2d read_batch_formatted_as_numpy after write_flattened_sequence read_batch_formatted_as_numpy after write_nested_sequence read_batch_unformated after write_array2d read_batch_unformated after write_flattened_sequence read_batch_unformated after write_nested_sequence read_col_formatted_as_numpy after write_array2d read_col_formatted_as_numpy after write_flattened_sequence read_col_formatted_as_numpy after write_nested_sequence read_col_unformated after write_array2d read_col_unformated after write_flattened_sequence read_col_unformated after write_nested_sequence read_formatted_as_numpy after write_array2d read_formatted_as_numpy after write_flattened_sequence read_formatted_as_numpy after write_nested_sequence read_unformated after write_array2d read_unformated after write_flattened_sequence read_unformated after write_nested_sequence write_array2d write_flattened_sequence write_nested_sequence
new / old (diff) 0.020230 / 0.011353 (0.008877) 0.013603 / 0.011008 (0.002595) 0.042849 / 0.038508 (0.004341) 0.036949 / 0.023109 (0.013840) 0.287314 / 0.275898 (0.011416) 0.318262 / 0.323480 (-0.005217) 0.010518 / 0.007986 (0.002532) 0.004955 / 0.004328 (0.000627) 0.010535 / 0.004250 (0.006285) 0.049014 / 0.037052 (0.011962) 0.283631 / 0.258489 (0.025141) 0.324495 / 0.293841 (0.030654) 0.133625 / 0.128546 (0.005079) 0.097754 / 0.075646 (0.022107) 0.364404 / 0.419271 (-0.054867) 0.388636 / 0.043533 (0.345103) 0.285333 / 0.255139 (0.030194) 0.310612 / 0.283200 (0.027412) 1.537208 / 0.141683 (1.395525) 1.547006 / 1.452155 (0.094851) 1.618002 / 1.492716 (0.125285)

Benchmark: benchmark_getitem_100B.json

metric get_batch_of_1024_random_rows get_batch_of_1024_rows get_first_row get_last_row
new / old (diff) 0.046014 / 0.018006 (0.028007) 0.515760 / 0.000490 (0.515270) 0.012442 / 0.000200 (0.012242) 0.000222 / 0.000054 (0.000168)

Benchmark: benchmark_indices_mapping.json

metric select shard shuffle sort train_test_split
new / old (diff) 0.037741 / 0.037411 (0.000329) 0.023607 / 0.014526 (0.009081) 0.027042 / 0.176557 (-0.149515) 0.049130 / 0.737135 (-0.688006) 0.027485 / 0.296338 (-0.268854)

Benchmark: benchmark_iterating.json

metric read 5000 read 50000 read_batch 50000 10 read_batch 50000 100 read_batch 50000 1000 read_formatted numpy 5000 read_formatted pandas 5000 read_formatted tensorflow 5000 read_formatted torch 5000 read_formatted_batch numpy 5000 10 read_formatted_batch numpy 5000 1000 shuffled read 5000 shuffled read 50000 shuffled read_batch 50000 10 shuffled read_batch 50000 100 shuffled read_batch 50000 1000 shuffled read_formatted numpy 5000 shuffled read_formatted_batch numpy 5000 10 shuffled read_formatted_batch numpy 5000 1000
new / old (diff) 0.350874 / 0.215209 (0.135665) 3.500416 / 2.077655 (1.422761) 1.731923 / 1.504120 (0.227803) 1.567320 / 1.541195 (0.026125) 1.639888 / 1.468490 (0.171398) 5.176285 / 4.584777 (0.591508) 4.646462 / 3.745712 (0.900750) 7.035045 / 5.269862 (1.765183) 5.935103 / 4.565676 (1.369426) 0.525704 / 0.424275 (0.101429) 0.009278 / 0.007607 (0.001671) 0.453470 / 0.226044 (0.227426) 4.512127 / 2.268929 (2.243198) 2.172167 / 55.444624 (-53.272458) 1.843887 / 6.876477 (-5.032590) 1.907953 / 2.142072 (-0.234120) 5.372815 / 4.805227 (0.567587) 4.970137 / 6.500664 (-1.530527) 11.389108 / 0.075469 (11.313639)

Benchmark: benchmark_map_filter.json

metric filter map fast-tokenizer batched map identity map identity batched map no-op batched map no-op batched numpy map no-op batched pandas map no-op batched pytorch map no-op batched tensorflow
new / old (diff) 11.225690 / 1.841788 (9.383903) 12.721363 / 8.074308 (4.647055) 25.634397 / 10.191392 (15.443005) 0.765395 / 0.680424 (0.084971) 0.515969 / 0.534201 (-0.018232) 0.617251 / 0.579283 (0.037968) 0.470687 / 0.434364 (0.036323) 0.542168 / 0.540337 (0.001830) 1.282155 / 1.386936 (-0.104781)

CML watermark

Please sign in to comment.