Skip to content

Commit

Permalink
Avoid PyArrow type optimization if it fails (#3234)
Browse files Browse the repository at this point in the history
* Add option to disable type optimization

* Add a test

* Add DISABLE prefix

* Style

* Revert changes

* Remove col in TypedSequence

* Add fallback in case of range error

* Add test

* Fix

* Log info message
  • Loading branch information
mariosasko authored Nov 10, 2021
1 parent ec37b34 commit 807341d
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/datasets/arrow_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def __arrow_array__(self, type=None):
trying_type = True
else:
type = self.type
trying_int_optimization = False
try:
if isinstance(type, _ArrayXDExtensionType):
if isinstance(self.data, np.ndarray):
Expand Down Expand Up @@ -130,6 +131,7 @@ def __arrow_array__(self, type=None):
"Specified try_type alters data. Please check that the type/feature that you provided match the type/features of the data."
)
if self.optimized_int_type and self.type is None and self.try_type is None:
trying_int_optimization = True
if pa.types.is_int64(out.type):
out = out.cast(self.optimized_int_type)
elif pa.types.is_list(out.type):
Expand All @@ -154,6 +156,10 @@ def __arrow_array__(self, type=None):
type_(self.data), e
)
) from None
elif trying_int_optimization and "not in range" in str(e):
optimized_int_type_str = np.dtype(self.optimized_int_type.to_pandas_dtype()).name
logger.info(f"Failed to cast a sequence to {optimized_int_type_str}. Falling back to int64.")
return out
else:
raise
elif "overflow" in str(e):
Expand All @@ -162,6 +168,10 @@ def __arrow_array__(self, type=None):
type_(self.data), e
)
) from None
elif trying_int_optimization and "not in range" in str(e):
optimized_int_type_str = np.dtype(self.optimized_int_type.to_pandas_dtype()).name
logger.info(f"Failed to cast a sequence to {optimized_int_type_str}. Falling back to int64.")
return out
else:
raise

Expand Down
19 changes: 19 additions & 0 deletions tests/test_arrow_writer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import copy
import os
import tempfile
from unittest import TestCase

import numpy as np
import pyarrow as pa
import pytest

Expand Down Expand Up @@ -211,6 +213,13 @@ def get_base_dtype(arr_type):
return arr_type


def change_first_primitive_element_in_list(lst, value):
    """Replace the first primitive (non-list) element of a possibly nested list with ``value``, in place.

    Walks down the chain of first elements until one is not a list, then
    overwrites it. Mirrors the nesting used by the parametrized test sequences.
    """
    target = lst
    # Descend along the leading elements until we hit a non-list leaf.
    while isinstance(target[0], list):
        target = target[0]
    target[0] = value


@pytest.mark.parametrize("optimized_int_type, expected_dtype", [(None, pa.int64()), (pa.int32(), pa.int32())])
@pytest.mark.parametrize("sequence", [[1, 2, 3], [[1, 2, 3]], [[[1, 2, 3]]]])
def test_optimized_int_type_for_typed_sequence(sequence, optimized_int_type, expected_dtype):
Expand All @@ -230,9 +239,19 @@ def test_optimized_int_type_for_typed_sequence(sequence, optimized_int_type, exp
)
@pytest.mark.parametrize("sequence", [[1, 2, 3], [[1, 2, 3]], [[[1, 2, 3]]]])
def test_optimized_typed_sequence(sequence, col, expected_dtype):
    """Check that OptimizedTypedSequence downcasts in-range values and falls back to int64 otherwise."""
    # Values fit the optimized type: the downcast should take effect.
    in_range_arr = pa.array(OptimizedTypedSequence(sequence, col=col))
    assert get_base_dtype(in_range_arr.type) == expected_dtype

    # Columns without an optimized int type have no out-of-range case to test.
    if col == "other":
        return

    # Deep-copy so the shared parametrized sequence is not mutated in place.
    out_of_range = copy.deepcopy(sequence)
    too_big = np.iinfo(expected_dtype.to_pandas_dtype()).max + 1
    change_first_primitive_element_in_list(out_of_range, too_big)
    fallback_arr = pa.array(OptimizedTypedSequence(out_of_range, col=col))
    assert get_base_dtype(fallback_arr.type) == pa.int64()


@pytest.mark.parametrize("raise_exception", [False, True])
def test_arrow_writer_closes_stream(raise_exception, tmp_path):
Expand Down

1 comment on commit 807341d

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Show benchmarks

PyArrow==3.0.0

Show updated benchmarks!

Benchmark: benchmark_array_xd.json

metric read_batch_formatted_as_numpy after write_array2d read_batch_formatted_as_numpy after write_flattened_sequence read_batch_formatted_as_numpy after write_nested_sequence read_batch_unformated after write_array2d read_batch_unformated after write_flattened_sequence read_batch_unformated after write_nested_sequence read_col_formatted_as_numpy after write_array2d read_col_formatted_as_numpy after write_flattened_sequence read_col_formatted_as_numpy after write_nested_sequence read_col_unformated after write_array2d read_col_unformated after write_flattened_sequence read_col_unformated after write_nested_sequence read_formatted_as_numpy after write_array2d read_formatted_as_numpy after write_flattened_sequence read_formatted_as_numpy after write_nested_sequence read_unformated after write_array2d read_unformated after write_flattened_sequence read_unformated after write_nested_sequence write_array2d write_flattened_sequence write_nested_sequence
new / old (diff) 0.070565 / 0.011353 (0.059212) 0.004222 / 0.011008 (-0.006786) 0.031360 / 0.038508 (-0.007148) 0.036070 / 0.023109 (0.012960) 0.334885 / 0.275898 (0.058987) 0.366230 / 0.323480 (0.042750) 0.084525 / 0.007986 (0.076539) 0.005220 / 0.004328 (0.000891) 0.009285 / 0.004250 (0.005035) 0.042954 / 0.037052 (0.005902) 0.334573 / 0.258489 (0.076084) 0.366421 / 0.293841 (0.072580) 0.085349 / 0.128546 (-0.043198) 0.009087 / 0.075646 (-0.066559) 0.252446 / 0.419271 (-0.166825) 0.046128 / 0.043533 (0.002595) 0.333806 / 0.255139 (0.078667) 0.360199 / 0.283200 (0.076999) 0.084323 / 0.141683 (-0.057360) 1.717922 / 1.452155 (0.265768) 1.734924 / 1.492716 (0.242208)

Benchmark: benchmark_getitem_100B.json

metric get_batch_of_1024_random_rows get_batch_of_1024_rows get_first_row get_last_row
new / old (diff) 0.326154 / 0.018006 (0.308147) 0.562520 / 0.000490 (0.562030) 0.003714 / 0.000200 (0.003514) 0.000123 / 0.000054 (0.000068)

Benchmark: benchmark_indices_mapping.json

metric select shard shuffle sort train_test_split
new / old (diff) 0.037533 / 0.037411 (0.000121) 0.021863 / 0.014526 (0.007337) 0.028250 / 0.176557 (-0.148307) 0.197289 / 0.737135 (-0.539846) 0.028950 / 0.296338 (-0.267389)

Benchmark: benchmark_iterating.json

metric read 5000 read 50000 read_batch 50000 10 read_batch 50000 100 read_batch 50000 1000 read_formatted numpy 5000 read_formatted pandas 5000 read_formatted tensorflow 5000 read_formatted torch 5000 read_formatted_batch numpy 5000 10 read_formatted_batch numpy 5000 1000 shuffled read 5000 shuffled read 50000 shuffled read_batch 50000 10 shuffled read_batch 50000 100 shuffled read_batch 50000 1000 shuffled read_formatted numpy 5000 shuffled read_formatted_batch numpy 5000 10 shuffled read_formatted_batch numpy 5000 1000
new / old (diff) 0.421013 / 0.215209 (0.205804) 4.194139 / 2.077655 (2.116484) 1.772151 / 1.504120 (0.268031) 1.556143 / 1.541195 (0.014949) 1.639803 / 1.468490 (0.171313) 0.419274 / 4.584777 (-4.165503) 4.641297 / 3.745712 (0.895585) 2.184463 / 5.269862 (-3.085399) 0.864159 / 4.565676 (-3.701517) 0.050394 / 0.424275 (-0.373881) 0.010830 / 0.007607 (0.003223) 0.524463 / 0.226044 (0.298418) 5.262250 / 2.268929 (2.993321) 2.254825 / 55.444624 (-53.189800) 1.884926 / 6.876477 (-4.991551) 2.050485 / 2.142072 (-0.091588) 0.539522 / 4.805227 (-4.265705) 0.113308 / 6.500664 (-6.387356) 0.056454 / 0.075469 (-0.019015)

Benchmark: benchmark_map_filter.json

metric filter map fast-tokenizer batched map identity map identity batched map no-op batched map no-op batched numpy map no-op batched pandas map no-op batched pytorch map no-op batched tensorflow
new / old (diff) 1.522269 / 1.841788 (-0.319518) 12.289042 / 8.074308 (4.214734) 26.952208 / 10.191392 (16.760816) 0.677184 / 0.680424 (-0.003239) 0.502048 / 0.534201 (-0.032153) 0.369315 / 0.579283 (-0.209969) 0.499788 / 0.434364 (0.065424) 0.256747 / 0.540337 (-0.283591) 0.266148 / 1.386936 (-1.120788)
PyArrow==latest
Show updated benchmarks!

Benchmark: benchmark_array_xd.json

metric read_batch_formatted_as_numpy after write_array2d read_batch_formatted_as_numpy after write_flattened_sequence read_batch_formatted_as_numpy after write_nested_sequence read_batch_unformated after write_array2d read_batch_unformated after write_flattened_sequence read_batch_unformated after write_nested_sequence read_col_formatted_as_numpy after write_array2d read_col_formatted_as_numpy after write_flattened_sequence read_col_formatted_as_numpy after write_nested_sequence read_col_unformated after write_array2d read_col_unformated after write_flattened_sequence read_col_unformated after write_nested_sequence read_formatted_as_numpy after write_array2d read_formatted_as_numpy after write_flattened_sequence read_formatted_as_numpy after write_nested_sequence read_unformated after write_array2d read_unformated after write_flattened_sequence read_unformated after write_nested_sequence write_array2d write_flattened_sequence write_nested_sequence
new / old (diff) 0.068624 / 0.011353 (0.057271) 0.004112 / 0.011008 (-0.006896) 0.029699 / 0.038508 (-0.008809) 0.033696 / 0.023109 (0.010587) 0.298379 / 0.275898 (0.022481) 0.329870 / 0.323480 (0.006390) 0.088594 / 0.007986 (0.080608) 0.004365 / 0.004328 (0.000036) 0.007613 / 0.004250 (0.003363) 0.038300 / 0.037052 (0.001248) 0.300713 / 0.258489 (0.042224) 0.339834 / 0.293841 (0.045993) 0.084313 / 0.128546 (-0.044234) 0.008995 / 0.075646 (-0.066651) 0.251252 / 0.419271 (-0.168019) 0.045451 / 0.043533 (0.001918) 0.302640 / 0.255139 (0.047501) 0.319974 / 0.283200 (0.036774) 0.084675 / 0.141683 (-0.057008) 1.673895 / 1.452155 (0.221741) 1.740900 / 1.492716 (0.248183)

Benchmark: benchmark_getitem_100B.json

metric get_batch_of_1024_random_rows get_batch_of_1024_rows get_first_row get_last_row
new / old (diff) 0.307621 / 0.018006 (0.289614) 0.573892 / 0.000490 (0.573402) 0.002232 / 0.000200 (0.002033) 0.000090 / 0.000054 (0.000035)

Benchmark: benchmark_indices_mapping.json

metric select shard shuffle sort train_test_split
new / old (diff) 0.033202 / 0.037411 (-0.004209) 0.021772 / 0.014526 (0.007247) 0.030534 / 0.176557 (-0.146023) 0.201474 / 0.737135 (-0.535662) 0.032969 / 0.296338 (-0.263370)

Benchmark: benchmark_iterating.json

metric read 5000 read 50000 read_batch 50000 10 read_batch 50000 100 read_batch 50000 1000 read_formatted numpy 5000 read_formatted pandas 5000 read_formatted tensorflow 5000 read_formatted torch 5000 read_formatted_batch numpy 5000 10 read_formatted_batch numpy 5000 1000 shuffled read 5000 shuffled read 50000 shuffled read_batch 50000 10 shuffled read_batch 50000 100 shuffled read_batch 50000 1000 shuffled read_formatted numpy 5000 shuffled read_formatted_batch numpy 5000 10 shuffled read_formatted_batch numpy 5000 1000
new / old (diff) 0.427302 / 0.215209 (0.212093) 4.283054 / 2.077655 (2.205399) 1.829680 / 1.504120 (0.325560) 1.615021 / 1.541195 (0.073826) 1.697338 / 1.468490 (0.228848) 0.420858 / 4.584777 (-4.163919) 4.697026 / 3.745712 (0.951314) 2.138258 / 5.269862 (-3.131603) 0.898892 / 4.565676 (-3.666785) 0.050825 / 0.424275 (-0.373450) 0.011083 / 0.007607 (0.003475) 0.960678 / 0.226044 (0.734633) 11.423592 / 2.268929 (9.154663) 2.358918 / 55.444624 (-53.085706) 1.941956 / 6.876477 (-4.934521) 2.067300 / 2.142072 (-0.074773) 0.539886 / 4.805227 (-4.265341) 0.114267 / 6.500664 (-6.386397) 0.056103 / 0.075469 (-0.019366)

Benchmark: benchmark_map_filter.json

metric filter map fast-tokenizer batched map identity map identity batched map no-op batched map no-op batched numpy map no-op batched pandas map no-op batched pytorch map no-op batched tensorflow
new / old (diff) 1.570715 / 1.841788 (-0.271073) 18.542675 / 8.074308 (10.468367) 27.742867 / 10.191392 (17.551475) 0.850269 / 0.680424 (0.169845) 0.527612 / 0.534201 (-0.006589) 0.375970 / 0.579283 (-0.203314) 0.509188 / 0.434364 (0.074825) 0.264834 / 0.540337 (-0.275504) 0.280697 / 1.386936 (-1.106239)

CML watermark

Please sign in to comment.