Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add polars compatibility #6531

Merged
merged 42 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
83d1fbb
Add Polars support for data formatting and conversion
Dec 24, 2023
fec2bb3
Update Polars availability check in config.py
Dec 24, 2023
622e54e
Merge branch 'add-polars-compatibility' of github.com:psmyth94/datase…
Dec 24, 2023
82b9b7c
Merge branch 'add-polars-compatibility' of github.com:psmyth94/datase…
Dec 24, 2023
aae3f5a
Merge branch 'add-polars-compatibility' of github.com:psmyth94/datase…
Dec 24, 2023
7374e99
added to_polars
Mar 6, 2024
408b9d6
changed the logic of importing polars if not already called
Mar 6, 2024
39a5c56
Remove to and from_polars from table.py in order to maintain pa.table…
Mar 6, 2024
3aa7081
Merge branch 'main' into add-polars-compatibility
psmyth94 Mar 6, 2024
2f10384
fix unused import
Mar 6, 2024
a623f51
Merge branch 'add-polars-compatibility' of github.com:psmyth94/datase…
Mar 6, 2024
12fef57
fixed code formatting with ruff
Mar 6, 2024
a57fcbe
fix formatting issues with ruff
Mar 7, 2024
912c437
fix formatting issues using ruff
Mar 7, 2024
ce7c3c5
add tests for polars formatting
Mar 7, 2024
1b28d85
removed using InMemoryTable classmethod to convert polars to Table
Mar 7, 2024
eb4d7ce
added test for polars conversion
Mar 7, 2024
417f9ad
added missing ruff fixes
Mar 7, 2024
19e5d80
add polars in test dependencies
Mar 7, 2024
7c835a4
Fixed not executing default write method due to nested polars check.
Mar 7, 2024
d0582f9
Merge branch 'main' into add-polars-compatibility
psmyth94 Mar 7, 2024
fa51fd2
Update src/datasets/arrow_dataset.py
psmyth94 Mar 7, 2024
d09839c
Update src/datasets/arrow_dataset.py
psmyth94 Mar 7, 2024
1c6d2a5
Fix Polars DataFrame conversion bug
Mar 7, 2024
40614ee
Merge branch 'add-polars-compatibility' of github.com:psmyth94/datase…
Mar 7, 2024
7d6224b
Fix DataFrame conversion in arrow_dataset.py
Mar 7, 2024
a301eb3
Fix variable name in arrow_dataset.py
Mar 7, 2024
d062b57
Fix write_table to write_row in Dataset class
Mar 7, 2024
1dbdc80
fix formatting with ruff
Mar 7, 2024
1b9e450
Update polars dependency to include timezone support
Mar 7, 2024
23329d5
Remove polars in EXTRAS_REQUIRE
Mar 7, 2024
f4361cc
Replace deprecated method
Mar 7, 2024
53f471a
perform cleanup after use
Mar 7, 2024
52fd448
Merge branch 'main' into add-polars-compatibility
psmyth94 Mar 7, 2024
b898b57
remove unused import
Mar 7, 2024
4bacf3b
Add garbage collection to test_to_polars method
Mar 7, 2024
358b2cb
Remove unused import and unnecessary code in test_to_polars method
Mar 7, 2024
c00efad
Add additional args for to_polars method
Mar 7, 2024
ddaab5b
Fixed unclosed links to dataset file
Mar 7, 2024
a87998c
ruff cleanup
Mar 7, 2024
d1acc92
even ruffier cleanup
Mar 7, 2024
7ee3fdf
changed hash to reflect new SHA for ref/convert/parquet
HuggingFaceDocBuilder Mar 8, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@
"transformers",
"typing-extensions>=4.6.1", # due to conflict between apache-beam and pydantic
"zstandard",
"polars[timezone]>=0.20.0",
]


Expand Down
109 changes: 109 additions & 0 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@
if TYPE_CHECKING:
import sqlite3

import polars as pl
import pyspark
import sqlalchemy

Expand Down Expand Up @@ -868,6 +869,48 @@ def from_pandas(
table = table.cast(features.arrow_schema)
return cls(table, info=info, split=split)

@classmethod
def from_polars(
    cls,
    df: "pl.DataFrame",
    features: Optional[Features] = None,
    info: Optional[DatasetInfo] = None,
    split: Optional[NamedSplit] = None,
) -> "Dataset":
    """
    Create a `Dataset` from a `polars.DataFrame`.

    The DataFrame's underlying Arrow arrays are collected into an Arrow Table;
    this is mostly zero copy (CategoricalType columns are the data types that
    do copy).

    Args:
        df (`polars.DataFrame`): DataFrame to convert to Arrow Table
        features (`Features`, optional): Dataset features.
        info (`DatasetInfo`, optional): Dataset information, like description, citation, etc.
        split (`NamedSplit`, optional): Name of the dataset split.

    Examples:
    ```py
    >>> ds = Dataset.from_polars(df)
    ```
    """
    # `features` and `info.features`, when both given, must agree.
    if features is not None and info is not None and info.features != features:
        raise ValueError(
            f"Features specified in `features` and `info.features` can't be different:\n{features}\n{info.features}"
        )
    # Fall back to the features carried by `info` when `features` is not given.
    if features is None and info is not None:
        features = info.features
    if info is None:
        info = DatasetInfo()
    info.features = features
    arrow_table = InMemoryTable(df.to_arrow())
    if features is not None:
        # more expensive cast than building the table with schema=features.arrow_schema,
        # but needed to support e.g. the str -> Audio conversion
        arrow_table = arrow_table.cast(features.arrow_schema)
    return cls(arrow_table, info=info, split=split)

@classmethod
def from_dict(
cls,
Expand Down Expand Up @@ -3319,6 +3362,10 @@ def validate_function_output(processed_inputs, indices):
)
elif isinstance(indices, list) and isinstance(processed_inputs, Mapping):
allowed_batch_return_types = (list, np.ndarray, pd.Series)
if config.POLARS_AVAILABLE and "polars" in sys.modules:
import polars as pl

allowed_batch_return_types += (pl.Series, pl.DataFrame)
if config.TF_AVAILABLE and "tensorflow" in sys.modules:
import tensorflow as tf

Expand Down Expand Up @@ -3438,6 +3485,10 @@ def init_buffer_and_writer():
# If `update_data` is True after processing the first example/batch, initialize these resources with `init_buffer_and_writer`
buf_writer, writer, tmp_file = None, None, None

# Check if Polars is available and import it if so
if config.POLARS_AVAILABLE and "polars" in sys.modules:
import polars as pl

# Optionally initialize the writer as a context manager
with contextlib.ExitStack() as stack:
try:
Expand All @@ -3464,6 +3515,12 @@ def init_buffer_and_writer():
writer.write_row(example)
elif isinstance(example, pd.DataFrame):
writer.write_row(pa.Table.from_pandas(example))
elif (
config.POLARS_AVAILABLE
and "polars" in sys.modules
and isinstance(example, pl.DataFrame)
):
writer.write_row(example.to_arrow())
else:
writer.write(example)
num_examples_progress_update += 1
Expand Down Expand Up @@ -3497,6 +3554,10 @@ def init_buffer_and_writer():
writer.write_table(batch)
elif isinstance(batch, pd.DataFrame):
writer.write_table(pa.Table.from_pandas(batch))
elif (
config.POLARS_AVAILABLE and "polars" in sys.modules and isinstance(batch, pl.DataFrame)
):
writer.write_table(batch.to_arrow())
else:
writer.write_batch(batch)
num_examples_progress_update += num_examples_in_batch
Expand Down Expand Up @@ -4949,6 +5010,54 @@ def to_pandas(
for offset in range(0, len(self), batch_size)
)

def to_polars(
    self, batch_size: Optional[int] = None, batched: bool = False
) -> Union["pl.DataFrame", Iterator["pl.DataFrame"]]:
    """Returns the dataset as a `polars.DataFrame`. Can also return a generator for large datasets.

    Args:
        batch_size (`int`, *optional*):
            The size (number of rows) of the batches if `batched` is `True`.
            Defaults to `datasets.config.DEFAULT_MAX_BATCH_SIZE`.
        batched (`bool`):
            Set to `True` to return a generator that yields the dataset as batches
            of `batch_size` rows. Defaults to `False` (returns the whole dataset at once).

    Returns:
        `polars.DataFrame` or `Iterator[polars.DataFrame]`

    Raises:
        ValueError: If `polars` is not installed.

    Example:

    ```py
    >>> ds.to_polars()
    ```
    """
    # Guard clause: fail fast when the optional polars dependency is missing.
    if not config.POLARS_AVAILABLE:
        raise ValueError("Polars needs to be installed to be able to return Polars dataframes.")
    import polars as pl

    if not batched:
        # Whole dataset in one shot, honoring any indices mapping.
        return pl.from_arrow(
            query_table(
                table=self._data,
                key=slice(0, len(self)),
                indices=self._indices,
            )
        )
    batch_size = batch_size if batch_size else config.DEFAULT_MAX_BATCH_SIZE
    # Lazily yield one DataFrame per slice of `batch_size` rows.
    return (
        pl.from_arrow(
            query_table(
                table=self._data,
                key=slice(offset, offset + batch_size),
                indices=self._indices,
            )
        )
        for offset in range(0, len(self), batch_size)
    )

def to_parquet(
self,
path_or_buf: Union[PathLike, BinaryIO],
Expand Down
10 changes: 10 additions & 0 deletions src/datasets/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,16 @@
else:
logger.info("Disabling PyTorch because USE_TF is set")

# Optional polars dependency: detect it without importing it (find_spec only
# inspects the import machinery, so startup stays cheap when polars is absent).
POLARS_VERSION = "N/A"
POLARS_AVAILABLE = importlib.util.find_spec("polars") is not None

if POLARS_AVAILABLE:
    try:
        # Resolve the installed distribution's version for logging.
        POLARS_VERSION = version.parse(importlib.metadata.version("polars"))
        logger.info(f"Polars version {POLARS_VERSION} available.")
    except importlib.metadata.PackageNotFoundError:
        # NOTE(review): the module spec exists but no distribution metadata was
        # found (e.g. an unusual install); POLARS_VERSION stays "N/A" while
        # POLARS_AVAILABLE remains True — confirm this asymmetry is intended.
        pass

TF_VERSION = "N/A"
TF_AVAILABLE = False

Expand Down
8 changes: 8 additions & 0 deletions src/datasets/formatting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,14 @@ def _register_unavailable_formatter(
_register_formatter(PandasFormatter, "pandas", aliases=["pd"])
_register_formatter(CustomFormatter, "custom")

# Register the polars formatter only when the optional dependency is installed;
# otherwise register a placeholder that raises a helpful error on first use.
if config.POLARS_AVAILABLE:
    from .polars_formatter import PolarsFormatter

    _register_formatter(PolarsFormatter, "polars", aliases=["pl"])
else:
    _polars_error = ValueError("Polars needs to be installed to be able to return Polars dataframes.")
    _register_unavailable_formatter(_polars_error, "polars", aliases=["pl"])

if config.TORCH_AVAILABLE:
from .torch_formatter import TorchFormatter

Expand Down
122 changes: 122 additions & 0 deletions src/datasets/formatting/polars_formatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Copyright 2020 The HuggingFace Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
from collections.abc import Mapping
from functools import partial
from typing import TYPE_CHECKING, Optional

import pyarrow as pa

from .. import config
from ..features import Features
from ..features.features import decode_nested_example
from ..utils.py_utils import no_op_if_value_is_null
from .formatting import BaseArrowExtractor, TensorFormatter


if TYPE_CHECKING:
import polars as pl


class PolarsArrowExtractor(BaseArrowExtractor["pl.DataFrame", "pl.Series", "pl.DataFrame"]):
    """Extract rows, columns, and batches from a `pyarrow.Table` as polars objects."""

    @staticmethod
    def _polars():
        """Return the polars module, raising if polars is not installed.

        A plain `import` already consults `sys.modules` first, so there is no
        need for a manual `sys.modules` lookup to avoid re-importing.
        """
        if not config.POLARS_AVAILABLE:
            raise ValueError("Polars needs to be installed to be able to return Polars dataframes.")
        import polars

        return polars

    def extract_row(self, pa_table: pa.Table) -> "pl.DataFrame":
        # First row only: slice(length=1) avoids converting the whole table.
        return self._polars().from_arrow(pa_table.slice(length=1))

    def extract_column(self, pa_table: pa.Table) -> "pl.Series":
        polars = self._polars()
        # Select the first column and return it as a Series.
        return polars.from_arrow(pa_table.select([0]))[pa_table.column_names[0]]

    def extract_batch(self, pa_table: pa.Table) -> "pl.DataFrame":
        return self._polars().from_arrow(pa_table)


class PolarsFeaturesDecoder:
    """Decode feature columns (e.g. Audio, Image) of polars rows/columns/batches."""

    def __init__(self, features: Optional[Features]):
        # Optional Features describing which columns need decoding.
        self.features = features
        import polars as pl  # noqa: F401 - import pl at initialization

    def decode_row(self, row: "pl.DataFrame") -> "pl.DataFrame":
        # Map each column that requires decoding to a null-safe decode callable.
        decode = (
            {
                column_name: no_op_if_value_is_null(partial(decode_nested_example, feature))
                for column_name, feature in self.features.items()
                if self.features._column_requires_decoding[column_name]
            }
            if self.features
            else {}
        )
        if decode:
            # NOTE(review): passing a dict to `map_rows` and assigning columns via
            # `row[...] = ...` relies on specific polars behavior — confirm this
            # path is exercised; it only runs when a column requires decoding.
            row[list(decode.keys())] = row.map_rows(decode)
        return row

    def decode_column(self, column: "pl.Series", column_name: str) -> "pl.Series":
        # Build a decode callable only when this column's feature requires decoding.
        decode = (
            no_op_if_value_is_null(partial(decode_nested_example, self.features[column_name]))
            if self.features and column_name in self.features and self.features._column_requires_decoding[column_name]
            else None
        )
        if decode:
            # Element-wise decode; nulls are passed through by no_op_if_value_is_null.
            column = column.map_elements(decode)
        return column

    def decode_batch(self, batch: "pl.DataFrame") -> "pl.DataFrame":
        # A batch is decoded exactly like a (multi-row) DataFrame row.
        return self.decode_row(batch)


class PolarsFormatter(TensorFormatter[Mapping, "pl.DataFrame", Mapping]):
    """Formatter that renders pyarrow tables as polars DataFrames/Series, decoding features."""

    def __init__(self, features=None, **np_array_kwargs):
        super().__init__(features=features)
        self.np_array_kwargs = np_array_kwargs
        # Extractor is stored as a class and instantiated per call; decoder is shared.
        self.polars_arrow_extractor = PolarsArrowExtractor
        self.polars_features_decoder = PolarsFeaturesDecoder(features)
        import polars as pl  # noqa: F401 - import pl at initialization

    def format_row(self, pa_table: pa.Table) -> "pl.DataFrame":
        """Extract the first row as a one-row DataFrame and decode its features."""
        raw_row = self.polars_arrow_extractor().extract_row(pa_table)
        return self.polars_features_decoder.decode_row(raw_row)

    def format_column(self, pa_table: pa.Table) -> "pl.Series":
        """Extract the first column as a Series and decode it."""
        raw_column = self.polars_arrow_extractor().extract_column(pa_table)
        return self.polars_features_decoder.decode_column(raw_column, pa_table.column_names[0])

    def format_batch(self, pa_table: pa.Table) -> "pl.DataFrame":
        """Extract the whole table as a DataFrame and decode its features."""
        raw_batch = self.polars_arrow_extractor().extract_batch(pa_table)
        return self.polars_features_decoder.decode_batch(raw_batch)
Loading
Loading