diff --git a/setup.py b/setup.py
index efd84a29ddc..0e993445b96 100644
--- a/setup.py
+++ b/setup.py
@@ -185,6 +185,7 @@
     "transformers",
     "typing-extensions>=4.6.1",  # due to conflict between apache-beam and pydantic
     "zstandard",
+    "polars[timezone]>=0.20.0",
 ]
diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index 9c49378471a..c155030c594 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -134,6 +134,7 @@
 if TYPE_CHECKING:
     import sqlite3

+    import polars as pl
     import pyspark
     import sqlalchemy

@@ -868,6 +869,48 @@ def from_pandas(
             table = table.cast(features.arrow_schema)
         return cls(table, info=info, split=split)

+    @classmethod
+    def from_polars(
+        cls,
+        df: "pl.DataFrame",
+        features: Optional[Features] = None,
+        info: Optional[DatasetInfo] = None,
+        split: Optional[NamedSplit] = None,
+    ) -> "Dataset":
+        """
+        Convert a `polars.DataFrame` to a `Dataset` by collecting the underlying Arrow arrays in an Arrow Table.
+
+        This operation is mostly zero-copy.
+
+        Data types that do copy:
+            * CategoricalType
+
+        Args:
+            df (`polars.DataFrame`): DataFrame to convert to Arrow Table.
+            features (`Features`, optional): Dataset features.
+            info (`DatasetInfo`, optional): Dataset information, like description, citation, etc.
+            split (`NamedSplit`, optional): Name of the dataset split.
+
+        Examples:
+        ```py
+        >>> ds = Dataset.from_polars(df)
+        ```
+        """
+        if info is not None and features is not None and info.features != features:
+            raise ValueError(
+                f"Features specified in `features` and `info.features` can't be different:\n{features}\n{info.features}"
+            )
+        features = features if features is not None else info.features if info is not None else None
+        if info is None:
+            info = DatasetInfo()
+        info.features = features
+        table = InMemoryTable(df.to_arrow())
+        if features is not None:
+            # more expensive cast than InMemoryTable.from_polars(..., schema=features.arrow_schema)
+            # needed to support the str to Audio conversion for instance
+            table = table.cast(features.arrow_schema)
+        return cls(table, info=info, split=split)
+
     @classmethod
     def from_dict(
         cls,
@@ -3319,6 +3362,10 @@ def validate_function_output(processed_inputs, indices):
             )
         elif isinstance(indices, list) and isinstance(processed_inputs, Mapping):
             allowed_batch_return_types = (list, np.ndarray, pd.Series)
+            if config.POLARS_AVAILABLE and "polars" in sys.modules:
+                import polars as pl
+
+                allowed_batch_return_types += (pl.Series, pl.DataFrame)
             if config.TF_AVAILABLE and "tensorflow" in sys.modules:
                 import tensorflow as tf

@@ -3438,6 +3485,10 @@ def init_buffer_and_writer():
         # If `update_data` is True after processing the first example/batch, initialize these resources with `init_buffer_and_writer`
         buf_writer, writer, tmp_file = None, None, None

+        # Check if Polars is available and import it if so
+        if config.POLARS_AVAILABLE and "polars" in sys.modules:
+            import polars as pl
+
         # Optionally initialize the writer as a context manager
         with contextlib.ExitStack() as stack:
             try:
@@ -3464,6 +3515,12 @@ def init_buffer_and_writer():
                             writer.write_row(example)
                         elif isinstance(example, pd.DataFrame):
                             writer.write_row(pa.Table.from_pandas(example))
+                        elif (
+                            config.POLARS_AVAILABLE
+                            and "polars" in sys.modules
+                            and isinstance(example, pl.DataFrame)
+                        ):
+                            writer.write_row(example.to_arrow())
                         else:
                             writer.write(example)
                         num_examples_progress_update += 1
@@ -3497,6 +3554,10 @@ def init_buffer_and_writer():
                             writer.write_table(batch)
                         elif isinstance(batch, pd.DataFrame):
                             writer.write_table(pa.Table.from_pandas(batch))
+                        elif (
+                            config.POLARS_AVAILABLE and "polars" in sys.modules and isinstance(batch, pl.DataFrame)
+                        ):
+                            writer.write_table(batch.to_arrow())
                         else:
                             writer.write_batch(batch)
                         num_examples_progress_update += num_examples_in_batch
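A minimal sketch (not part of the diff) of what the writer changes above enable: a batched `map` function may now return a `polars.DataFrame`, which the writer converts with `.to_arrow()`. Assumes Polars is installed and imported so the `isinstance` checks are active.

```py
import polars as pl

from datasets import Dataset

ds = Dataset.from_dict({"text": ["a", "b", "c"]})

def upper(batch: dict) -> pl.DataFrame:
    # Returning a pl.DataFrame hits the new `writer.write_table(batch.to_arrow())` path
    return pl.DataFrame({"text": [t.upper() for t in batch["text"]]})

ds = ds.map(upper, batched=True)
```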
config.POLARS_AVAILABLE and "polars" in sys.modules and isinstance(batch, pl.DataFrame) + ): + writer.write_table(batch.to_arrow()) else: writer.write_batch(batch) num_examples_progress_update += num_examples_in_batch @@ -4949,6 +5010,66 @@ def to_pandas( for offset in range(0, len(self), batch_size) ) + def to_polars( + self, + batch_size: Optional[int] = None, + batched: bool = False, + schema_overrides: Optional[dict] = None, + rechunk: bool = True, + ) -> Union["pl.DataFrame", Iterator["pl.DataFrame"]]: + """Returns the dataset as a `polars.DataFrame`. Can also return a generator for large datasets. + + Args: + batched (`bool`): + Set to `True` to return a generator that yields the dataset as batches + of `batch_size` rows. Defaults to `False` (returns the whole datasets once). + batch_size (`int`, *optional*): + The size (number of rows) of the batches if `batched` is `True`. + Defaults to `genomicsml.datasets.config.DEFAULT_MAX_BATCH_SIZE`. + schema_overrides (`dict`, *optional*): + Support type specification or override of one or more columns; note that + any dtypes inferred from the schema param will be overridden. + rechunk (`bool`): + Make sure that all data is in contiguous memory. Defaults to `True`. + Returns: + `polars.DataFrame` or `Iterator[polars.DataFrame]` + + Example: + + ```py + >>> ds.to_polars() + ``` + """ + if config.POLARS_AVAILABLE: + import polars as pl + + if not batched: + return pl.from_arrow( + query_table( + table=self._data, + key=slice(0, len(self)), + indices=self._indices if self._indices is not None else None, + ), + schema_overrides=schema_overrides, + rechunk=rechunk, + ) + else: + batch_size = batch_size if batch_size else config.DEFAULT_MAX_BATCH_SIZE + return ( + pl.from_arrow( + query_table( + table=self._data, + key=slice(offset, offset + batch_size), + indices=self._indices if self._indices is not None else None, + ), + schema_overrides=schema_overrides, + rechunk=rechunk, + ) + for offset in range(0, len(self), batch_size) + ) + else: + raise ValueError("Polars needs to be installed to be able to return Polars dataframes.") + def to_parquet( self, path_or_buf: Union[PathLike, BinaryIO], diff --git a/src/datasets/config.py b/src/datasets/config.py index 32127bea7dc..41c7ff9c3fe 100644 --- a/src/datasets/config.py +++ b/src/datasets/config.py @@ -61,6 +61,16 @@ else: logger.info("Disabling PyTorch because USE_TF is set") +POLARS_VERSION = "N/A" +POLARS_AVAILABLE = importlib.util.find_spec("polars") is not None + +if POLARS_AVAILABLE: + try: + POLARS_VERSION = version.parse(importlib.metadata.version("polars")) + logger.info(f"Polars version {POLARS_VERSION} available.") + except importlib.metadata.PackageNotFoundError: + pass + TF_VERSION = "N/A" TF_AVAILABLE = False diff --git a/src/datasets/formatting/__init__.py b/src/datasets/formatting/__init__.py index ba1edcff070..78f64cfe912 100644 --- a/src/datasets/formatting/__init__.py +++ b/src/datasets/formatting/__init__.py @@ -80,6 +80,14 @@ def _register_unavailable_formatter( _register_formatter(PandasFormatter, "pandas", aliases=["pd"]) _register_formatter(CustomFormatter, "custom") +if config.POLARS_AVAILABLE: + from .polars_formatter import PolarsFormatter + + _register_formatter(PolarsFormatter, "polars", aliases=["pl"]) +else: + _polars_error = ValueError("Polars needs to be installed to be able to return Polars dataframes.") + _register_unavailable_formatter(_polars_error, "polars", aliases=["pl"]) + if config.TORCH_AVAILABLE: from .torch_formatter import TorchFormatter diff --git 
diff --git a/src/datasets/formatting/polars_formatter.py b/src/datasets/formatting/polars_formatter.py
new file mode 100644
index 00000000000..543bde52dd0
--- /dev/null
+++ b/src/datasets/formatting/polars_formatter.py
@@ -0,0 +1,122 @@
+# Copyright 2020 The HuggingFace Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+from collections.abc import Mapping
+from functools import partial
+from typing import TYPE_CHECKING, Optional
+
+import pyarrow as pa
+
+from .. import config
+from ..features import Features
+from ..features.features import decode_nested_example
+from ..utils.py_utils import no_op_if_value_is_null
+from .formatting import BaseArrowExtractor, TensorFormatter
+
+
+if TYPE_CHECKING:
+    import polars as pl
+
+
+class PolarsArrowExtractor(BaseArrowExtractor["pl.DataFrame", "pl.Series", "pl.DataFrame"]):
+    def extract_row(self, pa_table: pa.Table) -> "pl.DataFrame":
+        if config.POLARS_AVAILABLE:
+            if "polars" not in sys.modules:
+                import polars
+            else:
+                polars = sys.modules["polars"]
+
+            return polars.from_arrow(pa_table.slice(length=1))
+        else:
+            raise ValueError("Polars needs to be installed to be able to return Polars dataframes.")
+
+    def extract_column(self, pa_table: pa.Table) -> "pl.Series":
+        if config.POLARS_AVAILABLE:
+            if "polars" not in sys.modules:
+                import polars
+            else:
+                polars = sys.modules["polars"]
+
+            return polars.from_arrow(pa_table.select([0]))[pa_table.column_names[0]]
+        else:
+            raise ValueError("Polars needs to be installed to be able to return Polars dataframes.")
+
+    def extract_batch(self, pa_table: pa.Table) -> "pl.DataFrame":
+        if config.POLARS_AVAILABLE:
+            if "polars" not in sys.modules:
+                import polars
+            else:
+                polars = sys.modules["polars"]
+
+            return polars.from_arrow(pa_table)
+        else:
+            raise ValueError("Polars needs to be installed to be able to return Polars dataframes.")
+
+
+class PolarsFeaturesDecoder:
+    def __init__(self, features: Optional[Features]):
+        self.features = features
+        import polars as pl  # noqa: F401 - import pl at initialization
+
+    def decode_row(self, row: "pl.DataFrame") -> "pl.DataFrame":
+        decode = (
+            {
+                column_name: no_op_if_value_is_null(partial(decode_nested_example, feature))
+                for column_name, feature in self.features.items()
+                if self.features._column_requires_decoding[column_name]
+            }
+            if self.features
+            else {}
+        )
+        if decode:
+            row[list(decode.keys())] = row.map_rows(decode)
+        return row
+
+    def decode_column(self, column: "pl.Series", column_name: str) -> "pl.Series":
+        decode = (
+            no_op_if_value_is_null(partial(decode_nested_example, self.features[column_name]))
+            if self.features and column_name in self.features and self.features._column_requires_decoding[column_name]
+            else None
+        )
+        if decode:
+            column = column.map_elements(decode)
+        return column
+
+    def decode_batch(self, batch: "pl.DataFrame") -> "pl.DataFrame":
+        return self.decode_row(batch)
+
+
+class PolarsFormatter(TensorFormatter[Mapping, "pl.DataFrame", Mapping]):
+    def __init__(self, features=None, **np_array_kwargs):
+        super().__init__(features=features)
+        self.np_array_kwargs = np_array_kwargs
+        self.polars_arrow_extractor = PolarsArrowExtractor
+        self.polars_features_decoder = PolarsFeaturesDecoder(features)
+        import polars as pl  # noqa: F401 - import pl at initialization
+
+    def format_row(self, pa_table: pa.Table) -> "pl.DataFrame":
+        row = self.polars_arrow_extractor().extract_row(pa_table)
+        row = self.polars_features_decoder.decode_row(row)
+        return row
+
+    def format_column(self, pa_table: pa.Table) -> "pl.Series":
+        column = self.polars_arrow_extractor().extract_column(pa_table)
+        column = self.polars_features_decoder.decode_column(column, pa_table.column_names[0])
+        return column
+
+    def format_batch(self, pa_table: pa.Table) -> "pl.DataFrame":
+        row = self.polars_arrow_extractor().extract_batch(pa_table)
+        row = self.polars_features_decoder.decode_batch(row)
+        return row
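To make the extractor contract above concrete, a small sketch (assuming Polars is installed) of what each method returns for a toy Arrow table:

```py
import pyarrow as pa

from datasets.formatting.polars_formatter import PolarsArrowExtractor

pa_table = pa.table({"a": [1, 2, 3], "b": ["x", "y", "z"]})
extractor = PolarsArrowExtractor()
extractor.extract_row(pa_table)     # one-row pl.DataFrame, via pa_table.slice(length=1)
extractor.extract_column(pa_table)  # first column ("a") as a pl.Series
extractor.extract_batch(pa_table)   # the whole table as a pl.DataFrame
```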
diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py
index b101c4f5715..188540cc17e 100644
--- a/tests/test_arrow_dataset.py
+++ b/tests/test_arrow_dataset.py
@@ -58,6 +58,7 @@
     require_jax,
     require_not_windows,
     require_pil,
+    require_polars,
     require_pyspark,
     require_sqlalchemy,
     require_tf,
@@ -483,6 +484,22 @@ def test_set_format_pandas(self, in_memory):
                 self.assertEqual(len(dset[0].columns), 2)
                 self.assertEqual(dset[0]["col_2"].item(), "a")

+    @require_polars
+    def test_set_format_polars(self, in_memory):
+        import polars as pl
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
+                dset.set_format(type="polars", columns=["col_1"])
+                self.assertEqual(len(dset[0].columns), 1)
+                self.assertIsInstance(dset[0], pl.DataFrame)
+                self.assertListEqual(list(dset[0].shape), [1, 1])
+                self.assertEqual(dset[0]["col_1"].item(), 3)
+
+                dset.set_format(type="polars", columns=["col_1", "col_2"])
+                self.assertEqual(len(dset[0].columns), 2)
+                self.assertEqual(dset[0]["col_2"].item(), "a")
+
     def test_set_transform(self, in_memory):
         def transform(batch):
             return {k: [str(i).upper() for i in v] for k, v in batch.items()}
@@ -2365,6 +2382,38 @@ def test_to_pandas(self, in_memory):
                 for col_name in dset.column_names:
                     self.assertEqual(len(dset_to_pandas[col_name]), dset.num_rows)

+    @require_polars
+    def test_to_polars(self, in_memory):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            # Batched
+            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
+                batch_size = dset.num_rows - 1
+                to_polars_generator = dset.to_polars(batched=True, batch_size=batch_size)
+
+                for batch in to_polars_generator:
+                    self.assertIsInstance(batch, sys.modules["polars"].DataFrame)
+                    self.assertListEqual(sorted(batch.columns), sorted(dset.column_names))
+                    for col_name in dset.column_names:
+                        self.assertLessEqual(len(batch[col_name]), batch_size)
+                    del batch
+
+                # Full
+                dset_to_polars = dset.to_polars()
+                self.assertIsInstance(dset_to_polars, sys.modules["polars"].DataFrame)
+                self.assertListEqual(sorted(dset_to_polars.columns), sorted(dset.column_names))
+                for col_name in dset.column_names:
+                    self.assertEqual(len(dset_to_polars[col_name]), len(dset))
+
+                # With index mapping
+                with dset.select([1, 0, 3]) as dset:
+                    dset_to_polars = dset.to_polars()
+                    self.assertIsInstance(dset_to_polars, sys.modules["polars"].DataFrame)
+                    self.assertEqual(len(dset_to_polars), 3)
+                    self.assertListEqual(sorted(dset_to_polars.columns), sorted(dset.column_names))
+
+                    for col_name in dset.column_names:
+                        self.assertEqual(len(dset_to_polars[col_name]), dset.num_rows)
+
     def test_to_parquet(self, in_memory):
         with tempfile.TemporaryDirectory() as tmp_dir:
             # File path argument
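The tests above pin down the indexing behavior; a condensed sketch of the same expectations, assuming Polars is installed:

```py
import polars as pl

from datasets import Dataset

ds = Dataset.from_dict({"col_1": [3, 2, 1, 0], "col_2": ["a", "b", "c", "d"]})
ds.set_format(type="polars", columns=["col_1"])
assert isinstance(ds[0], pl.DataFrame)  # a 1x1 frame holding the selected column
assert ds[0]["col_1"].item() == 3
```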
@@ -2791,6 +2840,17 @@ def test_format_pandas(self, in_memory):
             self.assertIsInstance(dset[:2], pd.DataFrame)
             self.assertIsInstance(dset["col_1"], pd.Series)

+    @require_polars
+    def test_format_polars(self, in_memory):
+        import polars as pl
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
+                dset.set_format("polars")
+                self.assertIsInstance(dset[0], pl.DataFrame)
+                self.assertIsInstance(dset[:2], pl.DataFrame)
+                self.assertIsInstance(dset["col_1"], pl.Series)
+
     def test_transmit_format_single(self, in_memory):
         @transmit_format
         def my_single_transform(self, return_factory, *args, **kwargs):
@@ -3057,6 +3117,35 @@ def test_from_pandas(self):
         features = Features({"col_1": Sequence(Value("string")), "col_2": Value("string")})
         self.assertRaises(TypeError, Dataset.from_pandas, df, features=features)

+    @require_polars
+    def test_from_polars(self):
+        import polars as pl
+
+        data = {"col_1": [3, 2, 1, 0], "col_2": ["a", "b", "c", "d"]}
+        df = pl.from_dict(data)
+        with Dataset.from_polars(df) as dset:
+            self.assertListEqual(dset["col_1"], data["col_1"])
+            self.assertListEqual(dset["col_2"], data["col_2"])
+            self.assertListEqual(list(dset.features.keys()), ["col_1", "col_2"])
+            self.assertDictEqual(dset.features, Features({"col_1": Value("int64"), "col_2": Value("large_string")}))
+
+        features = Features({"col_1": Value("int64"), "col_2": Value("large_string")})
+        with Dataset.from_polars(df, features=features) as dset:
+            self.assertListEqual(dset["col_1"], data["col_1"])
+            self.assertListEqual(dset["col_2"], data["col_2"])
+            self.assertListEqual(list(dset.features.keys()), ["col_1", "col_2"])
+            self.assertDictEqual(dset.features, Features({"col_1": Value("int64"), "col_2": Value("large_string")}))
+
+        features = Features({"col_1": Value("int64"), "col_2": Value("large_string")})
+        with Dataset.from_polars(df, features=features, info=DatasetInfo(features=features)) as dset:
+            self.assertListEqual(dset["col_1"], data["col_1"])
+            self.assertListEqual(dset["col_2"], data["col_2"])
+            self.assertListEqual(list(dset.features.keys()), ["col_1", "col_2"])
+            self.assertDictEqual(dset.features, Features({"col_1": Value("int64"), "col_2": Value("large_string")}))
+
+        features = Features({"col_1": Sequence(Value("string")), "col_2": Value("large_string")})
+        self.assertRaises(TypeError, Dataset.from_polars, df, features=features)
+
     def test_from_dict(self):
         data = {"col_1": [3, 2, 1, 0], "col_2": ["a", "b", "c", "d"], "col_3": pa.array([True, False, True, False])}
         with Dataset.from_dict(data) as dset:
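One detail worth flagging from `test_from_polars` above: Polars `Utf8` columns surface as Arrow `large_string`, so the inferred features differ from the `from_pandas` case. A sketch:

```py
import polars as pl

from datasets import Dataset

ds = Dataset.from_polars(pl.DataFrame({"col_2": ["a", "b"]}))
print(ds.features)  # {'col_2': Value(dtype='large_string', id=None)}
```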
+ dset.set_format(type="polars", columns=["col_1"]) + for dset_split in dset.values(): + self.assertEqual(len(dset_split[0].columns), 1) + self.assertIsInstance(dset_split[0], pl.DataFrame) + self.assertEqual(dset_split[0].shape, (1, 1)) + self.assertEqual(dset_split[0]["col_1"].item(), 3) + + dset.set_format(type="polars", columns=["col_1", "col_2"]) + for dset_split in dset.values(): + self.assertEqual(len(dset_split[0].columns), 2) + self.assertEqual(dset_split[0]["col_2"].item(), "a") + del dset + def test_set_transform(self): def transform(batch): return {k: [str(i).upper() for i in v] for k, v in batch.items()} diff --git a/tests/test_formatting.py b/tests/test_formatting.py index 9ac7d2c2f58..975b145768f 100644 --- a/tests/test_formatting.py +++ b/tests/test_formatting.py @@ -18,7 +18,7 @@ ) from datasets.table import InMemoryTable -from .utils import require_jax, require_pil, require_sndfile, require_tf, require_torch +from .utils import require_jax, require_pil, require_polars, require_sndfile, require_tf, require_torch class AnyArray: @@ -143,6 +143,60 @@ def test_pandas_extractor_temporal(self): self.assertTrue(isinstance(batch["d"][0], datetime.datetime)) self.assertTrue(pd.api.types.is_datetime64_any_dtype(batch["d"].dtype)) + @require_polars + def test_polars_extractor(self): + import polars as pl + + from datasets.formatting.polars_formatter import PolarsArrowExtractor + + pa_table = self._create_dummy_table() + extractor = PolarsArrowExtractor() + row = extractor.extract_row(pa_table) + self.assertIsInstance(row, pl.DataFrame) + assert pl.Series.eq(row["a"], pl.Series("a", _COL_A)[:1]).all() + assert pl.Series.eq(row["b"], pl.Series("b", _COL_B)[:1]).all() + col = extractor.extract_column(pa_table) + assert pl.Series.eq(col, pl.Series("a", _COL_A)).all() + batch = extractor.extract_batch(pa_table) + self.assertIsInstance(batch, pl.DataFrame) + assert pl.Series.eq(batch["a"], pl.Series("a", _COL_A)).all() + assert pl.Series.eq(batch["b"], pl.Series("b", _COL_B)).all() + + @require_polars + def test_polars_nested(self): + import polars as pl + + from datasets.formatting.polars_formatter import PolarsArrowExtractor + + pa_table = self._create_dummy_table().drop(["a", "b", "d"]) + extractor = PolarsArrowExtractor() + row = extractor.extract_row(pa_table) + self.assertEqual(row["c"][0][0].dtype, pl.Float64) + self.assertEqual(row["c"].dtype, pl.List(pl.List(pl.Float64))) + col = extractor.extract_column(pa_table) + self.assertEqual(col[0][0].dtype, pl.Float64) + self.assertEqual(col[0].dtype, pl.List(pl.Float64)) + self.assertEqual(col.dtype, pl.List(pl.List(pl.Float64))) + batch = extractor.extract_batch(pa_table) + self.assertEqual(batch["c"][0][0].dtype, pl.Float64) + self.assertEqual(batch["c"][0].dtype, pl.List(pl.Float64)) + self.assertEqual(batch["c"].dtype, pl.List(pl.List(pl.Float64))) + + @require_polars + def test_polars_temporal(self): + from datasets.formatting.polars_formatter import PolarsArrowExtractor + + pa_table = self._create_dummy_table().drop(["a", "b", "c"]) + extractor = PolarsArrowExtractor() + row = extractor.extract_row(pa_table) + self.assertTrue(row["d"].dtype.is_temporal()) + col = extractor.extract_column(pa_table) + self.assertTrue(isinstance(col[0], datetime.datetime)) + self.assertTrue(col.dtype.is_temporal()) + batch = extractor.extract_batch(pa_table) + self.assertTrue(isinstance(batch["d"][0], datetime.datetime)) + self.assertTrue(batch["d"].dtype.is_temporal()) + class LazyDictTest(TestCase): def _create_dummy_table(self): @@ -271,6 +325,25 
@@ -271,6 +325,25 @@ def test_pandas_formatter(self):
         pd.testing.assert_series_equal(batch["a"], pd.Series(_COL_A, name="a"))
         pd.testing.assert_series_equal(batch["b"], pd.Series(_COL_B, name="b"))

+    @require_polars
+    def test_polars_formatter(self):
+        import polars as pl
+
+        from datasets.formatting import PolarsFormatter
+
+        pa_table = self._create_dummy_table()
+        formatter = PolarsFormatter()
+        row = formatter.format_row(pa_table)
+        self.assertIsInstance(row, pl.DataFrame)
+        assert pl.Series.eq(row["a"], pl.Series("a", _COL_A)[:1]).all()
+        assert pl.Series.eq(row["b"], pl.Series("b", _COL_B)[:1]).all()
+        col = formatter.format_column(pa_table)
+        assert pl.Series.eq(col, pl.Series("a", _COL_A)).all()
+        batch = formatter.format_batch(pa_table)
+        self.assertIsInstance(batch, pl.DataFrame)
+        assert pl.Series.eq(batch["a"], pl.Series("a", _COL_A)).all()
+        assert pl.Series.eq(batch["b"], pl.Series("b", _COL_B)).all()
+
     @require_torch
     def test_torch_formatter(self):
         import torch
diff --git a/tests/test_load.py b/tests/test_load.py
index 9261ed7db2f..2516c794c21 100644
--- a/tests/test_load.py
+++ b/tests/test_load.py
@@ -838,10 +838,10 @@ def test_HubDatasetModuleFactoryWithParquetExport(self):
         )
         assert module_factory_result.builder_configs_parameters.builder_configs[0].data_files == {
             "train": [
-                "hf://datasets/hf-internal-testing/dataset_with_script@da4ed81df5a1bcd916043c827b75994de8ef7eda/default/train/0000.parquet"
+                "hf://datasets/hf-internal-testing/dataset_with_script@8f965694d611974ef8661618ada1b5aeb1072915/default/train/0000.parquet"
             ],
             "validation": [
-                "hf://datasets/hf-internal-testing/dataset_with_script@da4ed81df5a1bcd916043c827b75994de8ef7eda/default/validation/0000.parquet"
+                "hf://datasets/hf-internal-testing/dataset_with_script@8f965694d611974ef8661618ada1b5aeb1072915/default/validation/0000.parquet"
             ],
         }
diff --git a/tests/utils.py b/tests/utils.py
index 1fe38e82774..0dc3d83533b 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -141,6 +141,18 @@ def require_torch(test_case):
     return test_case


+def require_polars(test_case):
+    """
+    Decorator marking a test that requires Polars.
+
+    These tests are skipped when Polars isn't installed.
+    """
+    if not config.POLARS_AVAILABLE:
+        test_case = unittest.skip("test requires Polars")(test_case)
+    return test_case
+
+
 def require_tf(test_case):
     """
     Decorator marking a test that requires TensorFlow.
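Taken together, the diff wires Polars in at every layer: conversion (`from_polars`/`to_polars`), formatting (`PolarsFormatter`), `map` batch handling, and the test guards. An end-to-end sketch of the intended workflow, assuming Polars is installed:

```py
import polars as pl

from datasets import Dataset

df = pl.DataFrame({"col_1": [3, 2, 1, 0], "col_2": ["a", "b", "c", "d"]})
ds = Dataset.from_polars(df)   # polars -> Dataset
ds = ds.with_format("polars")  # indexing and map batches now yield polars objects
ds = ds.map(lambda batch: batch.with_columns(pl.col("col_1") * 2), batched=True)
out = ds.to_polars()           # Dataset -> polars
print(out["col_1"].to_list())  # [6, 4, 2, 0]
```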