Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use default cast for sliced list arrays if pyarrow >= 4 #2497

Merged
merged 3 commits on
Jun 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@

logger = get_logger(__name__)

if int(pa.__version__.split(".")[0]) == 0:
if int(config.PYARROW_VERSION.split(".")[0]) == 0:
PYARROW_V0 = True
else:
PYARROW_V0 = False
Expand Down Expand Up @@ -935,7 +935,7 @@ def cast_(
schema = pa.schema({col_name: type[col_name].type for col_name in self._data.column_names})
dataset = self.with_format("arrow")
dataset = dataset.map(
lambda t: cast_with_sliced_list_support(t, schema),
lambda t: t.cast(schema) if config.PYARROW_VERSION >= "4" else cast_with_sliced_list_support(t, schema),
batched=True,
batch_size=batch_size,
keep_in_memory=keep_in_memory,
Expand Down Expand Up @@ -994,7 +994,7 @@ def cast(
format = self.format
dataset = self.with_format("arrow")
dataset = dataset.map(
lambda t: cast_with_sliced_list_support(t, schema),
lambda t: t.cast(schema) if config.PYARROW_VERSION >= "4" else cast_with_sliced_list_support(t, schema),
batched=True,
batch_size=batch_size,
keep_in_memory=keep_in_memory,
Expand Down
2 changes: 2 additions & 0 deletions src/datasets/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@


# Imports
PYARROW_VERSION = importlib_metadata.version("pyarrow")

USE_TF = os.environ.get("USE_TF", "AUTO").upper()
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()

Expand Down
3 changes: 1 addition & 2 deletions src/datasets/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@

import numpy as np
import posixpath
import pyarrow as pa
import requests
from tqdm.auto import tqdm

Expand Down Expand Up @@ -371,7 +370,7 @@ def cached_path(

def get_datasets_user_agent(user_agent: Optional[Union[str, dict]] = None) -> str:
ua = "datasets/{}; python/{}".format(__version__, config.PY_VERSION)
ua += "; pyarrow/{}".format(pa.__version__)
ua += "; pyarrow/{}".format(config.PYARROW_VERSION)
if config.TORCH_AVAILABLE:
ua += "; torch/{}".format(config.TORCH_VERSION)
if config.TF_AVAILABLE:
Expand Down
3 changes: 2 additions & 1 deletion tests/test_arrow_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest
from packaging import version

from datasets import config
from datasets.arrow_writer import ArrowWriter, OptimizedTypedSequence, TypedSequence
from datasets.features import Array2DExtensionType
from datasets.keyhash import DuplicatedKeysError, InvalidKeyError
Expand Down Expand Up @@ -57,7 +58,7 @@ def test_try_incompatible_extension_type(self):
self.assertEqual(arr.type, pa.string())

def test_catch_overflow(self):
if version.parse(pa.__version__) < version.parse("2.0.0"):
if version.parse(config.PYARROW_VERSION) < version.parse("2.0.0"):
with self.assertRaises(OverflowError):
_ = pa.array(TypedSequence([["x" * 1024]] * ((2 << 20) + 1))) # ListArray with a bit more than 2GB

Expand Down
3 changes: 2 additions & 1 deletion tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pyarrow as pa
import pytest

from datasets import config
from datasets.table import (
ConcatenationTable,
InMemoryTable,
Expand Down Expand Up @@ -745,7 +746,7 @@ def test_concatenation_table_cast(
for k, v in zip(in_memory_pa_table.schema.names, in_memory_pa_table.schema.types)
}
)
if pa.__version__ < "4":
if config.PYARROW_VERSION < "4":
with pytest.raises(pa.ArrowNotImplementedError):
ConcatenationTable.from_blocks(blocks).cast(schema)
else:
Expand Down