Skip to content

Commit

Permalink
feat(python): make PlFlavor singletons with hidden version
Browse files Browse the repository at this point in the history
  • Loading branch information
ruihe774 committed Jul 6, 2024
1 parent 1fe3059 commit e10cc94
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 29 deletions.
38 changes: 19 additions & 19 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1385,7 +1385,7 @@ def item(self, row: int | None = None, column: int | str | None = None) -> Any:
)
return s.get_index_signed(row)

def to_arrow(self, *, future: Flavor = Flavor.Compatible) -> pa.Table:
def to_arrow(self, *, future: Flavor | None = None) -> pa.Table:
"""
Collect the underlying arrow arrays in an Arrow Table.
Expand Down Expand Up @@ -1415,8 +1415,10 @@ def to_arrow(self, *, future: Flavor = Flavor.Compatible) -> pa.Table:
if not self.width: # 0x0 dataframe, cannot infer schema from batches
return pa.table({})

if isinstance(future, Flavor):
future = future.value # type: ignore[assignment]
if future is None:
future = False # type: ignore[assignment]
elif isinstance(future, Flavor):
future = future._version # type: ignore[attr-defined]

record_batches = self._df.to_arrow(future)
return pa.Table.from_batches(record_batches)
Expand Down Expand Up @@ -3291,7 +3293,7 @@ def write_ipc(
file: None,
*,
compression: IpcCompression = "uncompressed",
future: Flavor = Flavor.Future1,
future: Flavor | None = None,
) -> BytesIO: ...

@overload
Expand All @@ -3300,15 +3302,15 @@ def write_ipc(
file: str | Path | IO[bytes],
*,
compression: IpcCompression = "uncompressed",
future: Flavor = Flavor.Future1,
future: Flavor | None = None,
) -> None: ...

def write_ipc(
self,
file: str | Path | IO[bytes] | None,
*,
compression: IpcCompression = "uncompressed",
future: Flavor = Flavor.Future1,
future: Flavor | None = None,
) -> BytesIO | None:
"""
Write to Arrow IPC binary stream or Feather file.
Expand Down Expand Up @@ -3345,11 +3347,10 @@ def write_ipc(
elif isinstance(file, (str, Path)):
file = normalize_filepath(file)

if isinstance(future, Flavor):
future = future.value # type: ignore[assignment]
elif future is None:
# this is for backward compatibility
future = True
if future is None:
future = True # type: ignore[assignment]
elif isinstance(future, Flavor):
future = future._version # type: ignore[attr-defined]

if compression is None:
compression = "uncompressed"
Expand All @@ -3363,7 +3364,7 @@ def write_ipc_stream(
file: None,
*,
compression: IpcCompression = "uncompressed",
future: Flavor = Flavor.Future1,
future: Flavor | None = None,
) -> BytesIO: ...

@overload
Expand All @@ -3372,15 +3373,15 @@ def write_ipc_stream(
file: str | Path | IO[bytes],
*,
compression: IpcCompression = "uncompressed",
future: Flavor = Flavor.Future1,
future: Flavor | None = None,
) -> None: ...

def write_ipc_stream(
self,
file: str | Path | IO[bytes] | None,
*,
compression: IpcCompression = "uncompressed",
future: Flavor = Flavor.Future1,
future: Flavor | None = None,
) -> BytesIO | None:
"""
Write to Arrow IPC record batch stream.
Expand Down Expand Up @@ -3417,11 +3418,10 @@ def write_ipc_stream(
elif isinstance(file, (str, Path)):
file = normalize_filepath(file)

if isinstance(future, Flavor):
future = future.value # type: ignore[assignment]
elif future is None:
# this is for backward compatibility
future = True
if future is None:
future = True # type: ignore[assignment]
elif isinstance(future, Flavor):
future = future._version # type: ignore[attr-defined]

if compression is None:
compression = "uncompressed"
Expand Down
35 changes: 29 additions & 6 deletions py-polars/polars/interchange/protocol.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from enum import Enum, IntEnum
from enum import IntEnum
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -259,11 +259,22 @@ class CopyNotAllowedError(RuntimeError):
"""Exception raised when a copy is required, but `allow_copy` is set to `False`."""


class Flavor(Enum):
"""Data structure versioning."""
class Flavor:
"""Data structure flavor."""

Compatible = 0
Future1 = 1
def __init__(self) -> None:
msg = "it is not allowed to create a Flavor object"
raise TypeError(msg)

@staticmethod
def _with_version(version: int) -> Flavor:
flavor = Flavor.__new__(Flavor)
flavor._version = version # type: ignore[attr-defined]
return flavor

@staticmethod
def _highest() -> Flavor:
return Flavor._future1 # type: ignore[attr-defined]

@staticmethod
def highest() -> Flavor:
Expand All @@ -275,4 +286,16 @@ def highest() -> Flavor:
at any point without it being considered a breaking change.
"""
issue_unstable_warning("Using the highest flavor is considered unstable.")
return Flavor.Future1
return Flavor._highest()

@staticmethod
def compatible() -> Flavor:
"""Get the flavor that is compatible with older arrow implementation."""
return Flavor._compatible # type: ignore[attr-defined]

def __repr__(self) -> str:
return f"<{self.__class__.__module__}.{self.__class__.__qualname__}: {self._version}>" # type: ignore[attr-defined]


Flavor._compatible = Flavor._with_version(0) # type: ignore[attr-defined]
Flavor._future1 = Flavor._with_version(1) # type: ignore[attr-defined]
17 changes: 14 additions & 3 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,15 @@ def _from_buffers(
validity = validity._s
return cls._from_pyseries(PySeries._from_buffers(dtype, data, validity))

def _highest_flavor(self) -> int:
"""
Get the highest supported flavor version.
This is only used by pyo3-polars,
and it is simpler not to make it a static method.
"""
return Flavor._highest()._version # type: ignore[attr-defined]

@property
def dtype(self) -> DataType:
"""
Expand Down Expand Up @@ -4343,7 +4352,7 @@ def to_torch(self) -> torch.Tensor:
# tensor.rename(self.name)
return tensor

def to_arrow(self, *, future: Flavor = Flavor.Compatible) -> pa.Array:
def to_arrow(self, *, future: Flavor | None = None) -> pa.Array:
"""
Return the underlying Arrow array.
Expand All @@ -4366,8 +4375,10 @@ def to_arrow(self, *, future: Flavor = Flavor.Compatible) -> pa.Array:
3
]
"""
if isinstance(future, Flavor):
future = future.value # type: ignore[assignment]
if future is None:
future = False # type: ignore[assignment]
elif isinstance(future, Flavor):
future = future._version # type: ignore[attr-defined]
return self._s.to_arrow(future)

def to_pandas(
Expand Down
12 changes: 11 additions & 1 deletion py-polars/tests/unit/interop/test_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import pytest

import polars as pl
from polars.exceptions import ComputeError
from polars.exceptions import ComputeError, UnstableWarning
from polars.interchange.protocol import Flavor
from polars.testing import assert_frame_equal, assert_series_equal


Expand Down Expand Up @@ -702,3 +703,12 @@ def test_from_numpy_different_resolution_invalid() -> None:
pl.Series(
np.array(["2020-01-01"], dtype="datetime64[s]"), dtype=pl.Datetime("us")
)


def test_highest_flavor(monkeypatch: pytest.MonkeyPatch) -> None:
# change these if flavor version bumped
monkeypatch.setenv("POLARS_WARN_UNSTABLE", "1")
assert Flavor.compatible()._version == 0 # type: ignore[attr-defined]
with pytest.warns(UnstableWarning):
assert Flavor.highest()._version == 1 # type: ignore[attr-defined]
assert pl.Series([1])._highest_flavor() == 1

0 comments on commit e10cc94

Please sign in to comment.