From b6531253d270e39c0603fff0b4909c4cf813b85c Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Thu, 16 May 2024 14:26:37 +0400 Subject: [PATCH 01/11] feat(python): Add new `to_jax` method to support export to jax arrays from `DataFrame` --- py-polars/polars/dataframe/frame.py | 270 ++++++++++++++++-- py-polars/polars/series/series.py | 49 +++- py-polars/polars/type_aliases.py | 1 + py-polars/pyproject.toml | 4 +- py-polars/requirements-ci.txt | 2 + py-polars/tests/docs/run_doctest.py | 1 + py-polars/tests/unit/ml/__init__.py | 0 py-polars/tests/unit/ml/test_to_jax.py | 129 +++++++++ .../unit/{dataframe => ml}/test_to_torch.py | 62 ++-- 9 files changed, 469 insertions(+), 49 deletions(-) create mode 100644 py-polars/tests/unit/ml/__init__.py create mode 100644 py-polars/tests/unit/ml/test_to_jax.py rename py-polars/tests/unit/{dataframe => ml}/test_to_torch.py (83%) diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index f5649d28f6dc..06ee6f38f1aa 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -72,6 +72,7 @@ INTEGER_DTYPES, N_INFER_DEFAULT, Boolean, + Float32, Float64, Int32, Int64, @@ -102,7 +103,7 @@ from polars.functions import col, lit from polars.selectors import _expand_selector_dicts, _expand_selectors from polars.slice import PolarsSlice -from polars.type_aliases import DbWriteMode, TorchExportType +from polars.type_aliases import DbWriteMode, JaxExportType, TorchExportType with contextlib.suppress(ImportError): # Module not available when building docs from polars.polars import dtype_str_repr as _dtype_str_repr @@ -115,6 +116,7 @@ from typing import Literal import deltalake + import jax import torch from hvplot.plotting.core import hvPlotTabularPolars from xlsxwriter import Workbook @@ -1516,7 +1518,7 @@ def to_numpy( However, the C-like order might be more appropriate to use for downstream applications to prevent cloning data, e.g. when reshaping into a one-dimensional array. Note that this option only takes effect if - `structured` is set to `False` and the DataFrame dtypes allow for a + `structured` is set to `False` and the DataFrame dtypes allow a global dtype for all columns. allow_copy Allow memory to be copied to perform the conversion. If set to `False`, @@ -1609,6 +1611,200 @@ def raise_on_copy(msg: str) -> None: return out + @overload + def to_jax( + self, + return_type: Literal["array"] = ..., + *, + device: jax.Device | str | None = ..., + label: str | Expr | Sequence[str | Expr] | None = ..., + features: str | Expr | Sequence[str | Expr] | None = ..., + dtype: PolarsDataType | None = ..., + order: IndexOrder = ..., + ) -> jax.Array: ... + + @overload + def to_jax( + self, + return_type: Literal["dict"], + *, + device: jax.Device | str | None = ..., + label: str | Expr | Sequence[str | Expr] | None = ..., + features: str | Expr | Sequence[str | Expr] | None = ..., + dtype: PolarsDataType | None = ..., + order: IndexOrder = ..., + ) -> dict[str, jax.Array]: ... + + def to_jax( + self, + return_type: JaxExportType = "array", + *, + device: jax.Device | str | None = None, + label: str | Expr | Sequence[str | Expr] | None = None, + features: str | Expr | Sequence[str | Expr] | None = None, + dtype: PolarsDataType | None = None, + order: IndexOrder = "fortran", + ) -> jax.Array | dict[str, jax.Array]: + """ + Convert DataFrame to a 2D Jax Array, or dict of Jax Arrays. + + Parameters + ---------- + return_type : {"array", "dict"} + Set return type; a 2D Jax Array, or dict of Jax Arrays. + device + Specify the jax `Device` on which the array will be created; can provide + a string (such as "cpu", "gpu", or "tpu") in which case the device is + retrieved as `jax.devices(string)[0]`. For more specific control you + can supply the instantiated `Device` directly. If None, arrays are + created on the default device. + label + One or more column names, expressions, or selectors that label the feature + data; results in a `{"label": ..., "features": ...}` dict being returned + when `return_type` is "dict" instead of a `{"col": array, }` dict. + features + One or more column names, expressions, or selectors that contain the feature + data; if omitted, all columns that are not designated as part of the label + are used. Only applies when `return_type` is "dict". + dtype + Unify the dtype of all returned arrays; this casts any column that is + not already of the required dtype before converting to Array. Note that + export will be single-precision (32bit) unless the Jax config/environment + directs otherwise (eg: "jax_enable_x64" was set True in the config object + at startup, or "JAX_ENABLE_X64" is set to "1" in the environment). + order : {"c", "fortran"} + The index order of the returned Jax array, either C-like or Fortran-like. + + See Also + -------- + to_dummies + to_numpy + to_torch + + Examples + -------- + >>> df = pl.DataFrame( + ... { + ... "lbl": [0, 1, 2, 3], + ... "feat1": [1, 0, 0, 1], + ... "feat2": [1.5, -0.5, 0.0, -2.25], + ... } + ... ) + + Standard return type (2D Array), on the standard device: + + >>> df.to_jax() + Array([[ 0. , 1. , 1.5 ], + [ 1. , 0. , -0.5 ], + [ 2. , 0. , 0. ], + [ 3. , 1. , -2.25]], dtype=float32) + + Create the Array on the default GPU device: + + >>> a = df.to_jax(device="gpu") # doctest: +SKIP + >>> a.device() # doctest: +SKIP + GpuDevice(id=0, process_index=0) + + Create the Array on a specific GPU device: + + >>> gpu_device = jax.devices("gpu")[1]) # doctest: +SKIP + >>> a = df.to_jax(device=gpu_device) # doctest: +SKIP + >>> a.device() # doctest: +SKIP + GpuDevice(id=1, process_index=0) + + As a dictionary of individual Arrays: + + >>> df.to_jax("dict") + {'lbl': Array([0, 1, 2, 3], dtype=int32), + 'feat1': Array([1, 0, 0, 1], dtype=int32), + 'feat2': Array([ 1.5 , -0.5 , 0. , -2.25], dtype=float32)} + + As a "label" and "features" dictionary; note that as "features" is not + declared, it defaults to all the columns that are not in "label": + + >>> df.to_jax("dict", label="lbl") + {'label': Array([[0], + [1], + [2], + [3]], dtype=int32), + 'features': Array([[ 1. , 1.5 ], + [ 0. , -0.5 ], + [ 0. , 0. ], + [ 1. , -2.25]], dtype=float32)} + + As a "label" and "features" dictionary where each is designated using + a selector expression (which can also be used to cast the data if the + label and features are better-represented with different dtypes): + + >>> import polars.selectors as cs + >>> df.to_jax( + ... return_type="dict", + ... features=cs.float(), + ... label=pl.col("lbl").cast(pl.UInt8), + ... ) + {'label': Array([[0], + [1], + [2], + [3]], dtype=uint8), + 'features': Array([[ 1.5 ], + [-0.5 ], + [ 0. ], + [-2.25]], dtype=float32)} + """ + if return_type != "dict" and (label is not None or features is not None): + msg = "`label` and `features` only apply when `return_type` is 'dict'" + raise ValueError(msg) + elif return_type == "dict" and label is None and features is not None: + msg = "`label` is required if setting `features` when `return_type='dict'" + raise ValueError(msg) + + jx = import_optional( + "jax", + install_message="Please see `https://jax.readthedocs.io/en/latest/installation.html` " + "for specific installation recommendations for the Jax package", + ) + enabled_double_precision = jx.config.jax_enable_x64 or bool( + int(os.environ.get("JAX_ENABLE_X64", "1")) + ) + if dtype: + frame = self.cast(dtype) + elif not enabled_double_precision: + # enforce single-precision unless environment/config directs otherwise + frame = self.cast({Float64: Float32, Int64: Int32, UInt64: UInt32}) + else: + frame = self + + if isinstance(device, str): + device = jx.devices(device)[0] + + with contextlib.nullcontext() if device is None else jx.default_device(device): + if return_type == "array": + return jx.numpy.asarray( + # note: jax arrays are immutable, so can avoid a copy (vs torch) + a=frame.to_numpy(writable=False, use_pyarrow=False, order=order), + order="K", + ) + elif return_type == "dict": + if label is not None: + # return a {"label": array(s), "features": array(s)} dict + label_frame = frame.select(label) + features_frame = ( + frame.select(features) + if features is not None + else frame.drop(*label_frame.columns) + ) + return { + "label": label_frame.to_jax(), + "features": features_frame.to_jax(), + } + else: + # return a {"col": array} dict + return {srs.name: srs.to_jax() for srs in frame} + else: + valid_jax_types = ", ".join(get_args(JaxExportType)) + msg = f"invalid `return_type`: {return_type!r}\nExpected one of: {valid_jax_types}" + raise ValueError(msg) + @overload def to_torch( self, @@ -1648,31 +1844,35 @@ def to_torch( dtype: PolarsDataType | None = None, ) -> torch.Tensor | dict[str, torch.Tensor] | PolarsDataset: """ - Convert DataFrame to a 2D PyTorch tensor, Dataset, or dict of Tensors. + Convert DataFrame to a 2D PyTorch Tensor, Dataset, or dict of Tensors. .. versionadded:: 0.20.23 Parameters ---------- return_type : {"tensor", "dataset", "dict"} - Set return type; a 2D PyTorch tensor, PolarsDataset (a frame-specialized + Set return type; a 2D PyTorch Tensor, PolarsDataset (a frame-specialized TensorDataset), or dict of Tensors. label One or more column names, expressions, or selectors that label the feature data; when `return_type` is "dataset", the PolarsDataset will return `(features, label)` tensor tuples for each row. Otherwise, it returns - `(features,)` tensor tuples where the feature contains all the row data; - note that setting this parameter with any other result type will raise an - informative error. + `(features,)` tensor tuples where the feature contains all the row data. features One or more column names, expressions, or selectors that contain the feature data; if omitted, all columns that are not designated as part of the label - are used. This parameter is a no-op for return-types other than "dataset". + are used. dtype - Unify the dtype of all returned tensors; this casts any frame Series - that are not of the required dtype before converting to tensor. This - includes the label column *unless* the label is an expression (such - as `pl.col("label_column").cast(pl.Int16)`). + Unify the dtype of all returned tensors; this casts any column that is + not of the required dtype before converting to Tensor. This includes + the label column *unless* the label is an expression (such as + `pl.col("label_column").cast(pl.Int16)`). + + See Also + -------- + to_dummies + to_jax + to_numpy Examples -------- @@ -1699,6 +1899,19 @@ def to_torch( 'feat1': tensor([1, 0, 0, 1]), 'feat2': tensor([ 1.5000, -0.5000, 0.0000, -2.2500], dtype=torch.float64)} + As a "label" and "features" dictionary; note that as "features" is not + declared, it defaults to all the columns that are not in "label": + + >>> df.to_torch("dict", label="lbl", dtype=pl.Float32) + {'label': tensor([[0.], + [1.], + [2.], + [3.]]), + 'features': tensor([[ 1.0000, 1.5000], + [ 0.0000, -0.5000], + [ 0.0000, 0.0000], + [ 1.0000, -2.2500]])} + As a PolarsDataset, with f64 supertype: >>> ds = df.to_torch("dataset", dtype=pl.Float64) @@ -1711,7 +1924,7 @@ def to_torch( (tensor([[ 0.0000, 1.0000, 1.5000], [ 3.0000, 1.0000, -2.2500]], dtype=torch.float64),) - As a convenience the PolarsDataset can opt-in to half-precision data + As a convenience the PolarsDataset can opt in to half-precision data for experimentation (usually this would be set on the model/pipeline): >>> list(ds.half()) @@ -1735,7 +1948,7 @@ def to_torch( supported). >>> ds = df.to_torch( - ... "dataset", + ... return_type="dataset", ... dtype=pl.Float32, ... label=pl.col("lbl").cast(pl.Int16), ... ) @@ -1760,8 +1973,13 @@ def to_torch( ... batch_size=64, ... ) # doctest: +SKIP """ - if return_type != "dataset" and (label is not None or features is not None): - msg = "the `label` and `features` parameters can only be set when `return_type='dataset'`" + if return_type not in ("dataset", "dict") and ( + label is not None or features is not None + ): + msg = "`label` and `features` only apply when `return_type` is 'dataset' or 'dict'" + raise ValueError(msg) + elif return_type == "dict" and label is None and features is not None: + msg = "`label` is required if setting `features` when `return_type='dict'" raise ValueError(msg) torch = import_optional("torch") @@ -1774,10 +1992,28 @@ def to_torch( frame = self.cast(to_dtype) # type: ignore[arg-type] if return_type == "tensor": + # note: torch tensors are not immutable, so we must consider them writable return torch.from_numpy(frame.to_numpy(writable=True, use_pyarrow=False)) + elif return_type == "dict": - return {srs.name: srs.to_torch() for srs in frame} + if label is not None: + # return a {"label": tensor(s), "features": tensor(s)} dict + label_frame = frame.select(label) + features_frame = ( + frame.select(features) + if features is not None + else frame.drop(*label_frame.columns) + ) + return { + "label": label_frame.to_torch(), + "features": features_frame.to_torch(), + } + else: + # return a {"col": tensor} dict + return {srs.name: srs.to_torch() for srs in frame} + elif return_type == "dataset": + # return a torch Dataset object from polars.ml.torch import PolarsDataset return PolarsDataset(frame, label=label, features=features) diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index a3aaa3d001cb..858ef718f807 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -2,6 +2,8 @@ import contextlib import math +import os +from contextlib import nullcontext from datetime import date, datetime, time, timedelta from decimal import Decimal as PyDecimal from typing import ( @@ -66,6 +68,7 @@ Decimal, Duration, Enum, + Float32, Float64, Int8, Int16, @@ -117,6 +120,7 @@ if TYPE_CHECKING: import sys + import jax import torch from hvplot.plotting.core import hvPlotTabularPolars @@ -4472,9 +4476,52 @@ def to_numpy( return self._s.to_numpy(allow_copy=allow_copy, writable=writable) + def to_jax(self, device: jax.Device | str | None = None) -> jax.Array: + """ + Convert this Series to a Jax Array. + + Parameters + ---------- + device + Specify the jax `Device` on which the array will be created; can provide + a string (such as "cpu", "gpu", or "tpu") in which case the device is + retrieved as `jax.devices(string)[0]`. For more specific control you + can supply the instantiated `Device` directly. If None, arrays are + created on the default device. + + Examples + -------- + >>> s = pl.Series("x", [10.5, 0.0, -10.0, 5.5]) + >>> s.to_jax() + Array([ 10.5, 0. , -10. , 5.5], dtype=float32) + """ + jx = import_optional( + "jax", + install_message="Please see `https://jax.readthedocs.io/en/latest/installation.html` " + "for specific installation recommendations for the Jax package", + ) + if isinstance(device, str): + device = jx.devices(device)[0] + if ( + jx.config.jax_enable_x64 + or bool(int(os.environ.get("JAX_ENABLE_X64", "1"))) + or self.dtype not in {Float64, Int64, UInt64} + ): + srs = self + else: + single_precision = {Float64: Float32, Int64: Int32, UInt64: UInt32} + srs = self.cast(single_precision[self.dtype]) # type: ignore[index] + + with nullcontext() if device is None else jx.default_device(device): + return jx.numpy.asarray( + # note: jax arrays are immutable, so can avoid a copy (vs torch) + a=srs.to_numpy(writable=False, use_pyarrow=False), + order="K", + ) + def to_torch(self) -> torch.Tensor: """ - Convert this Series to a PyTorch tensor. + Convert this Series to a PyTorch Tensor. Examples -------- diff --git a/py-polars/polars/type_aliases.py b/py-polars/polars/type_aliases.py index b57dcee1f5a3..92daf0b2dd0c 100644 --- a/py-polars/polars/type_aliases.py +++ b/py-polars/polars/type_aliases.py @@ -163,6 +163,7 @@ DbWriteEngine: TypeAlias = Literal["sqlalchemy", "adbc"] DbWriteMode: TypeAlias = Literal["replace", "append", "fail"] EpochTimeUnit = Literal["ns", "us", "ms", "s", "d"] +JaxExportType: TypeAlias = Literal["array", "dict"] Orientation: TypeAlias = Literal["col", "row"] SearchSortedSide: TypeAlias = Literal["any", "left", "right"] TorchExportType: TypeAlias = Literal["tensor", "dataset", "dict"] diff --git a/py-polars/pyproject.toml b/py-polars/pyproject.toml index 75ee3f46ff93..9ac0df0cd121 100644 --- a/py-polars/pyproject.toml +++ b/py-polars/pyproject.toml @@ -58,11 +58,10 @@ pydantic = ["pydantic"] pyxlsb = ["pyxlsb >= 1.0"] sqlalchemy = ["sqlalchemy", "pandas"] timezone = ["backports.zoneinfo; python_version < '3.9'", "tzdata; platform_system == 'Windows'"] -torch = ["torch"] xlsx2csv = ["xlsx2csv >= 0.8.0"] xlsxwriter = ["xlsxwriter"] all = [ - "polars[adbc,async,cloudpickle,connectorx,deltalake,fastexcel,fsspec,gevent,numpy,pandas,plot,pyarrow,pydantic,iceberg,sqlalchemy,timezone,torch,xlsx2csv,xlsxwriter]", + "polars[adbc,async,cloudpickle,connectorx,deltalake,fastexcel,fsspec,gevent,numpy,pandas,plot,pyarrow,pydantic,iceberg,sqlalchemy,timezone,xlsx2csv,xlsxwriter]", ] [tool.maturin] @@ -92,6 +91,7 @@ module = [ "fsspec.*", "gevent", "hvplot.*", + "jax.*", "kuzu", "matplotlib.*", "moto.server", diff --git a/py-polars/requirements-ci.txt b/py-polars/requirements-ci.txt index 3086002307dd..fbb39463fced 100644 --- a/py-polars/requirements-ci.txt +++ b/py-polars/requirements-ci.txt @@ -4,4 +4,6 @@ # ------------------------------------------------------- --extra-index-url https://download.pytorch.org/whl/cpu torch +jax +jaxlib pyiceberg>=0.5.0 diff --git a/py-polars/tests/docs/run_doctest.py b/py-polars/tests/docs/run_doctest.py index 7da0150e3347..39b95a548ddc 100644 --- a/py-polars/tests/docs/run_doctest.py +++ b/py-polars/tests/docs/run_doctest.py @@ -58,6 +58,7 @@ # if the module is found in the environment those doctests will # run; if the module is not found, their doctests are skipped. OPTIONAL_MODULES_AND_METHODS: dict[str, set[str]] = { + "jax": {"to_jax"}, "torch": {"to_torch"}, } OPTIONAL_MODULES: set[str] = set() diff --git a/py-polars/tests/unit/ml/__init__.py b/py-polars/tests/unit/ml/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/py-polars/tests/unit/ml/test_to_jax.py b/py-polars/tests/unit/ml/test_to_jax.py new file mode 100644 index 000000000000..0d26ce57c7b3 --- /dev/null +++ b/py-polars/tests/unit/ml/test_to_jax.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +import polars as pl +import polars.selectors as cs +from polars.dependencies import _lazy_import + +# don't import jax until an actual test is triggered (the decorator already +# ensures the tests aren't run locally; this avoids premature local import) +jx, _ = _lazy_import("jax") +jxn, _ = _lazy_import("jax.numpy") + + +@pytest.fixture() +def df() -> pl.DataFrame: + return pl.DataFrame( + { + "x": [1, 2, 2, 3], + "y": [True, False, True, False], + "z": [1.5, -0.5, 0.0, -2.0], + }, + schema_overrides={"x": pl.Int8, "z": pl.Float32}, + ) + + +@pytest.mark.ci_only() +class TestJaxIntegration: + """Test coverage for `to_jax` conversions.""" + + def assert_array_equal( + self, actual: Any, expected: Any, nans_equal: bool = True + ) -> None: + assert isinstance(actual, jx.Array) + jxn.array_equal(actual, expected, equal_nan=nans_equal) + + def test_to_jax_from_series(self) -> None: + s = pl.Series("x", [1, 2, 3, 4], dtype=pl.Int8) + a = s.to_jax() + + assert list(a.shape) == [4] + self.assert_array_equal(a, jxn.array([1, 2, 3, 4], dtype=jxn.int8)) + + for dtype in (pl.Int32, pl.Int64, pl.UInt32, pl.UInt64): + a = s.cast(dtype).to_jax() + self.assert_array_equal(a, jxn.array([1, 2, 3, 4], dtype=jxn.int32)) + + def test_to_jax_array(self, df: pl.DataFrame) -> None: + a1 = df.to_jax() + a2 = df.to_jax("array") + a3 = df.to_jax("array", device="cpu") + a4 = df.to_jax("array", device=jx.devices("cpu")[0]) + + expected = jxn.array( + [ + [1.0, 1.0, 1.5], + [2.0, 0.0, -0.5], + [2.0, 1.0, 0.0], + [3.0, 0.0, -2.0], + ], + dtype=jxn.float32, + ) + for a in (a1, a2, a3, a4): + self.assert_array_equal(a, expected) + + def test_to_jax_dict(self, df: pl.DataFrame) -> None: + arr_dict = df.to_jax("dict") + + assert list(arr_dict.keys()) == ["x", "y", "z"] + + self.assert_array_equal(arr_dict["x"], jxn.array([1, 2, 2, 3], dtype=jxn.int8)) + self.assert_array_equal( + arr_dict["y"], jxn.array([True, False, True, False], dtype=jxn.bool) + ) + self.assert_array_equal( + arr_dict["z"], jxn.array([1.5, -0.5, 0.0, -2.0], dtype=jxn.float32) + ) + + def test_to_jax_feature_label_dict(self, df: pl.DataFrame) -> None: + df = pl.DataFrame( + { + "age": [25, 32, 45, 22, 34], + "income": [50000, 75000, 60000, 58000, 120000], + "education": ["bachelor", "master", "phd", "bachelor", "phd"], + "purchased": [False, True, True, False, True], + } + ).to_dummies("education", separator=":") + + lbl_feat_dict = df.to_jax(return_type="dict", label="purchased") + assert list(lbl_feat_dict.keys()) == ["label", "features"] + + self.assert_array_equal( + lbl_feat_dict["label"], + jxn.array([[False], [True], [True], [False], [True]], dtype=jxn.bool), + ) + self.assert_array_equal( + lbl_feat_dict["features"], + jxn.array( + [ + [25, 50000, 1, 0, 0], + [32, 75000, 0, 1, 0], + [45, 60000, 0, 0, 1], + [22, 58000, 1, 0, 0], + [34, 120000, 0, 0, 1], + ], + dtype=jxn.int32, + ), + ) + + def test_misc_errors(self, df: pl.DataFrame) -> None: + with pytest.raises( + ValueError, + match="invalid `return_type`: 'stroopwafel'", + ): + _res0 = df.to_jax("stroopwafel") # type: ignore[call-overload] + + with pytest.raises( + ValueError, + match="`label` is required if setting `features` when `return_type='dict'", + ): + _res2 = df.to_jax("dict", features=cs.float()) + + with pytest.raises( + ValueError, + match="`label` and `features` only apply when `return_type` is 'dict'", + ): + _res3 = df.to_jax(label="stroopwafel") diff --git a/py-polars/tests/unit/dataframe/test_to_torch.py b/py-polars/tests/unit/ml/test_to_torch.py similarity index 83% rename from py-polars/tests/unit/dataframe/test_to_torch.py rename to py-polars/tests/unit/ml/test_to_torch.py index be8de2d1f2d1..d52c754eb83f 100644 --- a/py-polars/tests/unit/dataframe/test_to_torch.py +++ b/py-polars/tests/unit/ml/test_to_torch.py @@ -9,7 +9,7 @@ from polars.dependencies import _lazy_import # don't import torch until an actual test is triggered (the decorator already -# ensures the tests aren't run locally, this will skip premature local import) +# ensures the tests aren't run locally; this avoids premature local import) torch, _ = _lazy_import("torch") @@ -29,27 +29,25 @@ def df() -> pl.DataFrame: class TestTorchIntegration: """Test coverage for `to_torch` conversions and `polars.ml.torch` classes.""" - def assert_tensor(self, actual: Any, expected: Any) -> None: + def assert_tensor_equal(self, actual: Any, expected: Any) -> None: torch.testing.assert_close(actual, expected) - def test_to_torch_series( - self, - ) -> None: + def test_to_torch_from_series(self) -> None: s = pl.Series("x", [1, 2, 3, 4], dtype=pl.Int8) t = s.to_torch() assert list(t.shape) == [4] - self.assert_tensor(t, torch.tensor([1, 2, 3, 4], dtype=torch.int8)) + self.assert_tensor_equal(t, torch.tensor([1, 2, 3, 4], dtype=torch.int8)) # note: torch doesn't natively support uint16/32/64. # confirm that we export to a suitable signed integer type s = s.cast(pl.UInt16) t = s.to_torch() - self.assert_tensor(t, torch.tensor([1, 2, 3, 4], dtype=torch.int32)) + self.assert_tensor_equal(t, torch.tensor([1, 2, 3, 4], dtype=torch.int32)) for dtype in (pl.UInt32, pl.UInt64): t = s.cast(dtype).to_torch() - self.assert_tensor(t, torch.tensor([1, 2, 3, 4], dtype=torch.int64)) + self.assert_tensor_equal(t, torch.tensor([1, 2, 3, 4], dtype=torch.int64)) def test_to_torch_tensor(self, df: pl.DataFrame) -> None: t1 = df.to_torch() @@ -63,11 +61,11 @@ def test_to_torch_dict(self, df: pl.DataFrame) -> None: assert list(td.keys()) == ["x", "y", "z"] - self.assert_tensor(td["x"], torch.tensor([1, 2, 2, 3], dtype=torch.int8)) - self.assert_tensor( + self.assert_tensor_equal(td["x"], torch.tensor([1, 2, 2, 3], dtype=torch.int8)) + self.assert_tensor_equal( td["y"], torch.tensor([True, False, True, False], dtype=torch.bool) ) - self.assert_tensor( + self.assert_tensor_equal( td["z"], torch.tensor([1.5, -0.5, 0.0, -2.0], dtype=torch.float32) ) @@ -81,11 +79,13 @@ def test_to_torch_dataset(self, df: pl.DataFrame) -> None: ts = ds[0] assert isinstance(ts, tuple) assert len(ts) == 1 - self.assert_tensor(ts[0], torch.tensor([1.0, 1.0, 1.5], dtype=torch.float64)) + self.assert_tensor_equal( + ts[0], torch.tensor([1.0, 1.0, 1.5], dtype=torch.float64) + ) def test_to_torch_dataset_feature_reorder(self, df: pl.DataFrame) -> None: ds = df.to_torch("dataset", label="x", features=["z", "y"]) - self.assert_tensor( + self.assert_tensor_equal( torch.tensor( [ [1.5000, 1.0000], @@ -96,15 +96,19 @@ def test_to_torch_dataset_feature_reorder(self, df: pl.DataFrame) -> None: ), ds.features, ) - self.assert_tensor(torch.tensor([1, 2, 2, 3], dtype=torch.int8), ds.labels) + self.assert_tensor_equal( + torch.tensor([1, 2, 2, 3], dtype=torch.int8), ds.labels + ) def test_to_torch_dataset_feature_subset(self, df: pl.DataFrame) -> None: ds = df.to_torch("dataset", label="x", features=["z"]) - self.assert_tensor( + self.assert_tensor_equal( torch.tensor([[1.5000], [-0.5000], [0.0000], [-2.0000]]), ds.features, ) - self.assert_tensor(torch.tensor([1, 2, 2, 3], dtype=torch.int8), ds.labels) + self.assert_tensor_equal( + torch.tensor([1, 2, 2, 3], dtype=torch.int8), ds.labels + ) def test_to_torch_dataset_index_slice(self, df: pl.DataFrame) -> None: ds = df.to_torch("dataset") @@ -113,11 +117,11 @@ def test_to_torch_dataset_index_slice(self, df: pl.DataFrame) -> None: expected = ( torch.tensor([[2.0000, 0.0000, -0.5000], [2.0000, 1.0000, 0.0000]]), ) - self.assert_tensor(expected, ts) + self.assert_tensor_equal(expected, ts) ts = ds[::2] expected = (torch.tensor([[1.0000, 1.0000, 1.5000], [2.0, 1.0, 0.0]]),) - self.assert_tensor(expected, ts) + self.assert_tensor_equal(expected, ts) @pytest.mark.parametrize( "index", @@ -132,7 +136,7 @@ def test_to_torch_dataset_index_multi(self, index: Any, df: pl.DataFrame) -> Non ts = ds[index] expected = (torch.tensor([[1.0, 1.0, 1.5], [3.0, 0.0, -2.0]]),) - self.assert_tensor(expected, ts) + self.assert_tensor_equal(expected, ts) assert ds.schema == {"features": torch.float32, "labels": None} def test_to_torch_dataset_index_range(self, df: pl.DataFrame) -> None: @@ -142,7 +146,7 @@ def test_to_torch_dataset_index_range(self, df: pl.DataFrame) -> None: expected = ( torch.tensor([[3.0, 0.0, -2.0], [2.0, 1.0, 0.0], [2.0, 0.0, -0.5]]), ) - self.assert_tensor(expected, ts) + self.assert_tensor_equal(expected, ts) def test_to_dataset_half_precision(self, df: pl.DataFrame) -> None: ds = df.to_torch("dataset", label="x") @@ -157,7 +161,7 @@ def test_to_dataset_half_precision(self, df: pl.DataFrame) -> None: torch.tensor([[1.0000, 1.5000], [1.0000, 0.0000]], dtype=torch.float16), torch.tensor([1.0, 2.0], dtype=torch.float16), ) - self.assert_tensor(expected, ts) + self.assert_tensor_equal(expected, ts) # only apply half precision to the feature data dsf16 = ds.half(labels=False) @@ -168,7 +172,7 @@ def test_to_dataset_half_precision(self, df: pl.DataFrame) -> None: torch.tensor([[1.0000, 1.5000], [1.0000, 0.0000]], dtype=torch.float16), torch.tensor([1, 2], dtype=torch.int8), ) - self.assert_tensor(expected, ts) + self.assert_tensor_equal(expected, ts) # only apply half precision to the label data dsf16 = ds.half(features=False) @@ -179,7 +183,7 @@ def test_to_dataset_half_precision(self, df: pl.DataFrame) -> None: torch.tensor([[1.0000, 1.5000], [1.0000, 0.0000]], dtype=torch.float32), torch.tensor([1.0, 2.0], dtype=torch.float16), ) - self.assert_tensor(expected, ts) + self.assert_tensor_equal(expected, ts) # no labels dsf16 = df.to_torch("dataset").half() @@ -192,7 +196,7 @@ def test_to_dataset_half_precision(self, df: pl.DataFrame) -> None: dtype=torch.float16, ), ) - self.assert_tensor(expected, ts) + self.assert_tensor_equal(expected, ts) @pytest.mark.parametrize( ("label", "features"), @@ -214,7 +218,7 @@ def test_to_torch_labelled_dataset( ] assert len(ts) == len(expected) for actual, exp in zip(ts, expected): - self.assert_tensor(exp, actual) + self.assert_tensor_equal(exp, actual) def test_to_torch_labelled_dataset_expr(self, df: pl.DataFrame) -> None: ds = df.to_torch( @@ -232,7 +236,7 @@ def test_to_torch_labelled_dataset_expr(self, df: pl.DataFrame) -> None: ) assert len(data) == len(expected) for actual, exp in zip(data, expected): - self.assert_tensor(exp, actual) + self.assert_tensor_equal(exp, actual) def test_to_torch_labelled_dataset_multi(self, df: pl.DataFrame) -> None: ds = df.to_torch("dataset", label=["x", "y"]) @@ -254,7 +258,7 @@ def test_to_torch_labelled_dataset_multi(self, df: pl.DataFrame) -> None: for actual, exp in zip(ts, expected): assert len(actual) == len(exp) for a, e in zip(actual, exp): - self.assert_tensor(e, a) + self.assert_tensor_equal(e, a) def test_misc_errors(self, df: pl.DataFrame) -> None: ds = df.to_torch("dataset") @@ -279,12 +283,12 @@ def test_misc_errors(self, df: pl.DataFrame) -> None: with pytest.raises( ValueError, - match="`label` and `features` parameters .* when `return_type='dataset'`", + match="`label` and `features` only apply when `return_type` is 'dataset' or 'dict'", ): _res3 = df.to_torch(label="stroopwafel") with pytest.raises( ValueError, - match="`label` and `features` parameters .* when `return_type='dataset'`", + match="`label` is required if setting `features` when `return_type='dict'", ): _res4 = df.to_torch("dict", features=cs.float()) From 4789e137a37efcac56213c4760994cf5179147ed Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Fri, 17 May 2024 18:24:31 +0400 Subject: [PATCH 02/11] minor docstring update --- py-polars/polars/dataframe/frame.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 06ee6f38f1aa..a4b47e84ced6 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -1733,8 +1733,8 @@ def to_jax( [ 1. , -2.25]], dtype=float32)} As a "label" and "features" dictionary where each is designated using - a selector expression (which can also be used to cast the data if the - label and features are better-represented with different dtypes): + a col or selector expression (which can also be used to cast the data + if the label and features are better-represented with different dtypes): >>> import polars.selectors as cs >>> df.to_jax( From edd2bece304e61c1fb48bac7968ebc715ccaaadb Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Fri, 17 May 2024 18:43:23 +0400 Subject: [PATCH 03/11] skip a test under py3.8 --- py-polars/polars/dataframe/frame.py | 3 ++- py-polars/tests/unit/ml/test_to_jax.py | 11 +++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index a4b47e84ced6..192d0e1d070b 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -1673,7 +1673,8 @@ def to_jax( directs otherwise (eg: "jax_enable_x64" was set True in the config object at startup, or "JAX_ENABLE_X64" is set to "1" in the environment). order : {"c", "fortran"} - The index order of the returned Jax array, either C-like or Fortran-like. + The index order of the returned Jax array, either C-like (row-major) or + Fortran-like (column-major). See Also -------- diff --git a/py-polars/tests/unit/ml/test_to_jax.py b/py-polars/tests/unit/ml/test_to_jax.py index 0d26ce57c7b3..cf1ceb2a14d3 100644 --- a/py-polars/tests/unit/ml/test_to_jax.py +++ b/py-polars/tests/unit/ml/test_to_jax.py @@ -1,5 +1,6 @@ from __future__ import annotations +import sys from typing import Any import pytest @@ -19,7 +20,7 @@ def df() -> pl.DataFrame: return pl.DataFrame( { "x": [1, 2, 2, 3], - "y": [True, False, True, False], + "y": [1, 0, 1, 0], "z": [1.5, -0.5, 0.0, -2.0], }, schema_overrides={"x": pl.Int8, "z": pl.Float32}, @@ -71,13 +72,15 @@ def test_to_jax_dict(self, df: pl.DataFrame) -> None: assert list(arr_dict.keys()) == ["x", "y", "z"] self.assert_array_equal(arr_dict["x"], jxn.array([1, 2, 2, 3], dtype=jxn.int8)) - self.assert_array_equal( - arr_dict["y"], jxn.array([True, False, True, False], dtype=jxn.bool) - ) + self.assert_array_equal(arr_dict["y"], jxn.array([1, 0, 1, 0], dtype=jxn.int32)) self.assert_array_equal( arr_dict["z"], jxn.array([1.5, -0.5, 0.0, -2.0], dtype=jxn.float32) ) + @pytest.mark.skipif( + sys.version_info < (3, 9), + reason="jax.numpy.bool requires Python >= 3.9", + ) def test_to_jax_feature_label_dict(self, df: pl.DataFrame) -> None: df = pl.DataFrame( { From a75560e6c2eef04e4c5d4f6a90714caf9aa306d9 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Fri, 17 May 2024 19:22:44 +0400 Subject: [PATCH 04/11] include a `versionadded` tag --- py-polars/polars/dataframe/frame.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 192d0e1d070b..ecd9552f1d0a 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -1648,6 +1648,8 @@ def to_jax( """ Convert DataFrame to a 2D Jax Array, or dict of Jax Arrays. + .. versionadded:: 0.20.27 + Parameters ---------- return_type : {"array", "dict"} From 5319780aff794087a73144763da510117a564de7 Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Sat, 18 May 2024 14:59:05 +0400 Subject: [PATCH 05/11] bonus test coverage --- py-polars/polars/series/series.py | 2 +- py-polars/tests/unit/ml/test_to_jax.py | 50 ++++++++++++++++++------ py-polars/tests/unit/ml/test_to_torch.py | 32 +++++++++++++++ 3 files changed, 70 insertions(+), 14 deletions(-) diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 858ef718f807..ec2f09d38c8d 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -4504,7 +4504,7 @@ def to_jax(self, device: jax.Device | str | None = None) -> jax.Array: device = jx.devices(device)[0] if ( jx.config.jax_enable_x64 - or bool(int(os.environ.get("JAX_ENABLE_X64", "1"))) + or bool(int(os.environ.get("JAX_ENABLE_X64", "0"))) or self.dtype not in {Float64, Int64, UInt64} ): srs = self diff --git a/py-polars/tests/unit/ml/test_to_jax.py b/py-polars/tests/unit/ml/test_to_jax.py index cf1ceb2a14d3..648545e15e41 100644 --- a/py-polars/tests/unit/ml/test_to_jax.py +++ b/py-polars/tests/unit/ml/test_to_jax.py @@ -1,7 +1,7 @@ from __future__ import annotations import sys -from typing import Any +from typing import TYPE_CHECKING, Any import pytest @@ -14,6 +14,9 @@ jx, _ = _lazy_import("jax") jxn, _ = _lazy_import("jax.numpy") +if TYPE_CHECKING: + from polars.datatypes import PolarsDataType + @pytest.fixture() def df() -> pl.DataFrame: @@ -37,16 +40,30 @@ def assert_array_equal( assert isinstance(actual, jx.Array) jxn.array_equal(actual, expected, equal_nan=nans_equal) - def test_to_jax_from_series(self) -> None: - s = pl.Series("x", [1, 2, 3, 4], dtype=pl.Int8) - a = s.to_jax() - - assert list(a.shape) == [4] - self.assert_array_equal(a, jxn.array([1, 2, 3, 4], dtype=jxn.int8)) - - for dtype in (pl.Int32, pl.Int64, pl.UInt32, pl.UInt64): - a = s.cast(dtype).to_jax() - self.assert_array_equal(a, jxn.array([1, 2, 3, 4], dtype=jxn.int32)) + @pytest.mark.parametrize( + ("dtype", "expected_jax_dtype"), + [ + (pl.Int8, "int8"), + (pl.Int16, "int16"), + (pl.Int32, "int32"), + (pl.Int64, "int32"), + (pl.UInt8, "uint8"), + (pl.UInt16, "uint16"), + (pl.UInt32, "uint32"), + (pl.UInt64, "uint32"), + ], + ) + def test_to_jax_from_series( + self, + dtype: PolarsDataType, + expected_jax_dtype: str, + ) -> None: + s = pl.Series("x", [1, 2, 3, 4], dtype=dtype) + for dvc in (None, "cpu", jx.devices("cpu")[0]): + self.assert_array_equal( + s.to_jax(device=dvc), + jxn.array([1, 2, 3, 4], dtype=getattr(jxn, expected_jax_dtype)), + ) def test_to_jax_array(self, df: pl.DataFrame) -> None: a1 = df.to_jax() @@ -68,15 +85,22 @@ def test_to_jax_array(self, df: pl.DataFrame) -> None: def test_to_jax_dict(self, df: pl.DataFrame) -> None: arr_dict = df.to_jax("dict") - assert list(arr_dict.keys()) == ["x", "y", "z"] self.assert_array_equal(arr_dict["x"], jxn.array([1, 2, 2, 3], dtype=jxn.int8)) self.assert_array_equal(arr_dict["y"], jxn.array([1, 0, 1, 0], dtype=jxn.int32)) self.assert_array_equal( - arr_dict["z"], jxn.array([1.5, -0.5, 0.0, -2.0], dtype=jxn.float32) + arr_dict["z"], + jxn.array([1.5, -0.5, 0.0, -2.0], dtype=jxn.float32), ) + arr_dict = df.to_jax("dict", dtype=pl.Float32) + for a, expected_data in zip( + arr_dict.values(), + ([1.0, 2.0, 2.0, 3.0], [1.0, 0.0, 1.0, 0.0], [1.5, -0.5, 0.0, -2.0]), + ): + self.assert_array_equal(a, jxn.array(expected_data, dtype=jxn.float32)) + @pytest.mark.skipif( sys.version_info < (3, 9), reason="jax.numpy.bool requires Python >= 3.9", diff --git a/py-polars/tests/unit/ml/test_to_torch.py b/py-polars/tests/unit/ml/test_to_torch.py index d52c754eb83f..3cfbe2bc01d8 100644 --- a/py-polars/tests/unit/ml/test_to_torch.py +++ b/py-polars/tests/unit/ml/test_to_torch.py @@ -69,6 +69,38 @@ def test_to_torch_dict(self, df: pl.DataFrame) -> None: td["z"], torch.tensor([1.5, -0.5, 0.0, -2.0], dtype=torch.float32) ) + def test_to_torch_feature_label_dict(self, df: pl.DataFrame) -> None: + df = pl.DataFrame( + { + "age": [25, 32, 45, 22, 34], + "income": [50000, 75000, 60000, 58000, 120000], + "education": ["bachelor", "master", "phd", "bachelor", "phd"], + "purchased": [False, True, True, False, True], + }, + schema_overrides={"age": pl.Int32, "income": pl.Int32}, + ).to_dummies("education", separator=":") + + lbl_feat_dict = df.to_torch(return_type="dict", label="purchased") + assert list(lbl_feat_dict.keys()) == ["label", "features"] + + self.assert_tensor_equal( + lbl_feat_dict["label"], + torch.tensor([[False], [True], [True], [False], [True]], dtype=torch.bool), + ) + self.assert_tensor_equal( + lbl_feat_dict["features"], + torch.tensor( + [ + [25, 50000, 1, 0, 0], + [32, 75000, 0, 1, 0], + [45, 60000, 0, 0, 1], + [22, 58000, 1, 0, 0], + [34, 120000, 0, 0, 1], + ], + dtype=torch.int32, + ), + ) + def test_to_torch_dataset(self, df: pl.DataFrame) -> None: ds = df.to_torch("dataset", dtype=pl.Float64) From 7bf53ddf86b86afd4b5e548e160e47c285700026 Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Sat, 18 May 2024 16:06:12 +0400 Subject: [PATCH 06/11] fix default `JAX_ENABLE_X64` env value --- py-polars/polars/dataframe/frame.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index ecd9552f1d0a..257d806b258c 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -1767,7 +1767,7 @@ def to_jax( "for specific installation recommendations for the Jax package", ) enabled_double_precision = jx.config.jax_enable_x64 or bool( - int(os.environ.get("JAX_ENABLE_X64", "1")) + int(os.environ.get("JAX_ENABLE_X64", "0")) ) if dtype: frame = self.cast(dtype) From d9f138601f77e1d955dcc7cf8d7fc0262191714c Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Sun, 19 May 2024 18:04:11 +0400 Subject: [PATCH 07/11] make use of `pytestmark` --- py-polars/tests/unit/ml/test_to_jax.py | 230 +++++----- py-polars/tests/unit/ml/test_to_torch.py | 525 +++++++++++------------ 2 files changed, 377 insertions(+), 378 deletions(-) diff --git a/py-polars/tests/unit/ml/test_to_jax.py b/py-polars/tests/unit/ml/test_to_jax.py index 648545e15e41..5dc0c172f084 100644 --- a/py-polars/tests/unit/ml/test_to_jax.py +++ b/py-polars/tests/unit/ml/test_to_jax.py @@ -14,6 +14,8 @@ jx, _ = _lazy_import("jax") jxn, _ = _lazy_import("jax.numpy") +pytestmark = pytest.mark.ci_only + if TYPE_CHECKING: from polars.datatypes import PolarsDataType @@ -30,127 +32,125 @@ def df() -> pl.DataFrame: ) -@pytest.mark.ci_only() -class TestJaxIntegration: - """Test coverage for `to_jax` conversions.""" +def assert_array_equal(actual: Any, expected: Any, nans_equal: bool = True) -> None: + assert isinstance(actual, jx.Array) + jxn.array_equal(actual, expected, equal_nan=nans_equal) + + +@pytest.mark.parametrize( + ("dtype", "expected_jax_dtype"), + [ + (pl.Int8, "int8"), + (pl.Int16, "int16"), + (pl.Int32, "int32"), + (pl.Int64, "int32"), + (pl.UInt8, "uint8"), + (pl.UInt16, "uint16"), + (pl.UInt32, "uint32"), + (pl.UInt64, "uint32"), + ], +) +def test_to_jax_from_series( + dtype: PolarsDataType, + expected_jax_dtype: str, +) -> None: + s = pl.Series("x", [1, 2, 3, 4], dtype=dtype) + for dvc in (None, "cpu", jx.devices("cpu")[0]): + assert_array_equal( + s.to_jax(device=dvc), + jxn.array([1, 2, 3, 4], dtype=getattr(jxn, expected_jax_dtype)), + ) + - def assert_array_equal( - self, actual: Any, expected: Any, nans_equal: bool = True - ) -> None: - assert isinstance(actual, jx.Array) - jxn.array_equal(actual, expected, equal_nan=nans_equal) +def test_to_jax_array(df: pl.DataFrame) -> None: + a1 = df.to_jax() + a2 = df.to_jax("array") + a3 = df.to_jax("array", device="cpu") + a4 = df.to_jax("array", device=jx.devices("cpu")[0]) - @pytest.mark.parametrize( - ("dtype", "expected_jax_dtype"), + expected = jxn.array( [ - (pl.Int8, "int8"), - (pl.Int16, "int16"), - (pl.Int32, "int32"), - (pl.Int64, "int32"), - (pl.UInt8, "uint8"), - (pl.UInt16, "uint16"), - (pl.UInt32, "uint32"), - (pl.UInt64, "uint32"), + [1.0, 1.0, 1.5], + [2.0, 0.0, -0.5], + [2.0, 1.0, 0.0], + [3.0, 0.0, -2.0], ], + dtype=jxn.float32, ) - def test_to_jax_from_series( - self, - dtype: PolarsDataType, - expected_jax_dtype: str, - ) -> None: - s = pl.Series("x", [1, 2, 3, 4], dtype=dtype) - for dvc in (None, "cpu", jx.devices("cpu")[0]): - self.assert_array_equal( - s.to_jax(device=dvc), - jxn.array([1, 2, 3, 4], dtype=getattr(jxn, expected_jax_dtype)), - ) - - def test_to_jax_array(self, df: pl.DataFrame) -> None: - a1 = df.to_jax() - a2 = df.to_jax("array") - a3 = df.to_jax("array", device="cpu") - a4 = df.to_jax("array", device=jx.devices("cpu")[0]) - - expected = jxn.array( - [ - [1.0, 1.0, 1.5], - [2.0, 0.0, -0.5], - [2.0, 1.0, 0.0], - [3.0, 0.0, -2.0], - ], - dtype=jxn.float32, - ) - for a in (a1, a2, a3, a4): - self.assert_array_equal(a, expected) - - def test_to_jax_dict(self, df: pl.DataFrame) -> None: - arr_dict = df.to_jax("dict") - assert list(arr_dict.keys()) == ["x", "y", "z"] - - self.assert_array_equal(arr_dict["x"], jxn.array([1, 2, 2, 3], dtype=jxn.int8)) - self.assert_array_equal(arr_dict["y"], jxn.array([1, 0, 1, 0], dtype=jxn.int32)) - self.assert_array_equal( - arr_dict["z"], - jxn.array([1.5, -0.5, 0.0, -2.0], dtype=jxn.float32), - ) + for a in (a1, a2, a3, a4): + assert_array_equal(a, expected) + - arr_dict = df.to_jax("dict", dtype=pl.Float32) - for a, expected_data in zip( - arr_dict.values(), - ([1.0, 2.0, 2.0, 3.0], [1.0, 0.0, 1.0, 0.0], [1.5, -0.5, 0.0, -2.0]), - ): - self.assert_array_equal(a, jxn.array(expected_data, dtype=jxn.float32)) +def test_to_jax_dict(df: pl.DataFrame) -> None: + arr_dict = df.to_jax("dict") + assert list(arr_dict.keys()) == ["x", "y", "z"] - @pytest.mark.skipif( - sys.version_info < (3, 9), - reason="jax.numpy.bool requires Python >= 3.9", + assert_array_equal(arr_dict["x"], jxn.array([1, 2, 2, 3], dtype=jxn.int8)) + assert_array_equal(arr_dict["y"], jxn.array([1, 0, 1, 0], dtype=jxn.int32)) + assert_array_equal( + arr_dict["z"], + jxn.array([1.5, -0.5, 0.0, -2.0], dtype=jxn.float32), ) - def test_to_jax_feature_label_dict(self, df: pl.DataFrame) -> None: - df = pl.DataFrame( - { - "age": [25, 32, 45, 22, 34], - "income": [50000, 75000, 60000, 58000, 120000], - "education": ["bachelor", "master", "phd", "bachelor", "phd"], - "purchased": [False, True, True, False, True], - } - ).to_dummies("education", separator=":") - - lbl_feat_dict = df.to_jax(return_type="dict", label="purchased") - assert list(lbl_feat_dict.keys()) == ["label", "features"] - - self.assert_array_equal( - lbl_feat_dict["label"], - jxn.array([[False], [True], [True], [False], [True]], dtype=jxn.bool), - ) - self.assert_array_equal( - lbl_feat_dict["features"], - jxn.array( - [ - [25, 50000, 1, 0, 0], - [32, 75000, 0, 1, 0], - [45, 60000, 0, 0, 1], - [22, 58000, 1, 0, 0], - [34, 120000, 0, 0, 1], - ], - dtype=jxn.int32, - ), - ) - def test_misc_errors(self, df: pl.DataFrame) -> None: - with pytest.raises( - ValueError, - match="invalid `return_type`: 'stroopwafel'", - ): - _res0 = df.to_jax("stroopwafel") # type: ignore[call-overload] - - with pytest.raises( - ValueError, - match="`label` is required if setting `features` when `return_type='dict'", - ): - _res2 = df.to_jax("dict", features=cs.float()) - - with pytest.raises( - ValueError, - match="`label` and `features` only apply when `return_type` is 'dict'", - ): - _res3 = df.to_jax(label="stroopwafel") + arr_dict = df.to_jax("dict", dtype=pl.Float32) + for a, expected_data in zip( + arr_dict.values(), + ([1.0, 2.0, 2.0, 3.0], [1.0, 0.0, 1.0, 0.0], [1.5, -0.5, 0.0, -2.0]), + ): + assert_array_equal(a, jxn.array(expected_data, dtype=jxn.float32)) + + +@pytest.mark.skipif( + sys.version_info < (3, 9), + reason="jax.numpy.bool requires Python >= 3.9", +) +def test_to_jax_feature_label_dict(df: pl.DataFrame) -> None: + df = pl.DataFrame( + { + "age": [25, 32, 45, 22, 34], + "income": [50000, 75000, 60000, 58000, 120000], + "education": ["bachelor", "master", "phd", "bachelor", "phd"], + "purchased": [False, True, True, False, True], + } + ).to_dummies("education", separator=":") + + lbl_feat_dict = df.to_jax(return_type="dict", label="purchased") + assert list(lbl_feat_dict.keys()) == ["label", "features"] + + assert_array_equal( + lbl_feat_dict["label"], + jxn.array([[False], [True], [True], [False], [True]], dtype=jxn.bool), + ) + assert_array_equal( + lbl_feat_dict["features"], + jxn.array( + [ + [25, 50000, 1, 0, 0], + [32, 75000, 0, 1, 0], + [45, 60000, 0, 0, 1], + [22, 58000, 1, 0, 0], + [34, 120000, 0, 0, 1], + ], + dtype=jxn.int32, + ), + ) + + +def test_misc_errors(df: pl.DataFrame) -> None: + with pytest.raises( + ValueError, + match="invalid `return_type`: 'stroopwafel'", + ): + _res0 = df.to_jax("stroopwafel") # type: ignore[call-overload] + + with pytest.raises( + ValueError, + match="`label` is required if setting `features` when `return_type='dict'", + ): + _res2 = df.to_jax("dict", features=cs.float()) + + with pytest.raises( + ValueError, + match="`label` and `features` only apply when `return_type` is 'dict'", + ): + _res3 = df.to_jax(label="stroopwafel") diff --git a/py-polars/tests/unit/ml/test_to_torch.py b/py-polars/tests/unit/ml/test_to_torch.py index 3cfbe2bc01d8..7f1a4711c8ac 100644 --- a/py-polars/tests/unit/ml/test_to_torch.py +++ b/py-polars/tests/unit/ml/test_to_torch.py @@ -12,6 +12,8 @@ # ensures the tests aren't run locally; this avoids premature local import) torch, _ = _lazy_import("torch") +pytestmark = pytest.mark.ci_only + @pytest.fixture() def df() -> pl.DataFrame: @@ -25,302 +27,299 @@ def df() -> pl.DataFrame: ) -@pytest.mark.ci_only() -class TestTorchIntegration: - """Test coverage for `to_torch` conversions and `polars.ml.torch` classes.""" +def assert_tensor_equal(actual: Any, expected: Any) -> None: + torch.testing.assert_close(actual, expected) - def assert_tensor_equal(self, actual: Any, expected: Any) -> None: - torch.testing.assert_close(actual, expected) - def test_to_torch_from_series(self) -> None: - s = pl.Series("x", [1, 2, 3, 4], dtype=pl.Int8) - t = s.to_torch() +def test_to_torch_from_series() -> None: + s = pl.Series("x", [1, 2, 3, 4], dtype=pl.Int8) + t = s.to_torch() - assert list(t.shape) == [4] - self.assert_tensor_equal(t, torch.tensor([1, 2, 3, 4], dtype=torch.int8)) + assert list(t.shape) == [4] + assert_tensor_equal(t, torch.tensor([1, 2, 3, 4], dtype=torch.int8)) - # note: torch doesn't natively support uint16/32/64. - # confirm that we export to a suitable signed integer type - s = s.cast(pl.UInt16) - t = s.to_torch() - self.assert_tensor_equal(t, torch.tensor([1, 2, 3, 4], dtype=torch.int32)) + # note: torch doesn't natively support uint16/32/64. + # confirm that we export to a suitable signed integer type + s = s.cast(pl.UInt16) + t = s.to_torch() + assert_tensor_equal(t, torch.tensor([1, 2, 3, 4], dtype=torch.int32)) - for dtype in (pl.UInt32, pl.UInt64): - t = s.cast(dtype).to_torch() - self.assert_tensor_equal(t, torch.tensor([1, 2, 3, 4], dtype=torch.int64)) + for dtype in (pl.UInt32, pl.UInt64): + t = s.cast(dtype).to_torch() + assert_tensor_equal(t, torch.tensor([1, 2, 3, 4], dtype=torch.int64)) - def test_to_torch_tensor(self, df: pl.DataFrame) -> None: - t1 = df.to_torch() - t2 = df.to_torch("tensor") - assert list(t1.shape) == [4, 3] - assert (t1 == t2).all().item() is True +def test_to_torch_tensor(df: pl.DataFrame) -> None: + t1 = df.to_torch() + t2 = df.to_torch("tensor") - def test_to_torch_dict(self, df: pl.DataFrame) -> None: - td = df.to_torch("dict") + assert list(t1.shape) == [4, 3] + assert (t1 == t2).all().item() is True - assert list(td.keys()) == ["x", "y", "z"] - self.assert_tensor_equal(td["x"], torch.tensor([1, 2, 2, 3], dtype=torch.int8)) - self.assert_tensor_equal( - td["y"], torch.tensor([True, False, True, False], dtype=torch.bool) - ) - self.assert_tensor_equal( - td["z"], torch.tensor([1.5, -0.5, 0.0, -2.0], dtype=torch.float32) - ) +def test_to_torch_dict(df: pl.DataFrame) -> None: + td = df.to_torch("dict") - def test_to_torch_feature_label_dict(self, df: pl.DataFrame) -> None: - df = pl.DataFrame( - { - "age": [25, 32, 45, 22, 34], - "income": [50000, 75000, 60000, 58000, 120000], - "education": ["bachelor", "master", "phd", "bachelor", "phd"], - "purchased": [False, True, True, False, True], - }, - schema_overrides={"age": pl.Int32, "income": pl.Int32}, - ).to_dummies("education", separator=":") - - lbl_feat_dict = df.to_torch(return_type="dict", label="purchased") - assert list(lbl_feat_dict.keys()) == ["label", "features"] - - self.assert_tensor_equal( - lbl_feat_dict["label"], - torch.tensor([[False], [True], [True], [False], [True]], dtype=torch.bool), - ) - self.assert_tensor_equal( - lbl_feat_dict["features"], - torch.tensor( - [ - [25, 50000, 1, 0, 0], - [32, 75000, 0, 1, 0], - [45, 60000, 0, 0, 1], - [22, 58000, 1, 0, 0], - [34, 120000, 0, 0, 1], - ], - dtype=torch.int32, - ), - ) + assert list(td.keys()) == ["x", "y", "z"] - def test_to_torch_dataset(self, df: pl.DataFrame) -> None: - ds = df.to_torch("dataset", dtype=pl.Float64) + assert_tensor_equal(td["x"], torch.tensor([1, 2, 2, 3], dtype=torch.int8)) + assert_tensor_equal( + td["y"], torch.tensor([True, False, True, False], dtype=torch.bool) + ) + assert_tensor_equal( + td["z"], torch.tensor([1.5, -0.5, 0.0, -2.0], dtype=torch.float32) + ) - assert len(ds) == 4 - assert isinstance(ds, torch.utils.data.Dataset) - assert repr(ds).startswith(" None: + df = pl.DataFrame( + { + "age": [25, 32, 45, 22, 34], + "income": [50000, 75000, 60000, 58000, 120000], + "education": ["bachelor", "master", "phd", "bachelor", "phd"], + "purchased": [False, True, True, False, True], + }, + schema_overrides={"age": pl.Int32, "income": pl.Int32}, + ).to_dummies("education", separator=":") - def test_to_torch_dataset_feature_reorder(self, df: pl.DataFrame) -> None: - ds = df.to_torch("dataset", label="x", features=["z", "y"]) - self.assert_tensor_equal( - torch.tensor( - [ - [1.5000, 1.0000], - [-0.5000, 0.0000], - [0.0000, 1.0000], - [-2.0000, 0.0000], - ] - ), - ds.features, - ) - self.assert_tensor_equal( - torch.tensor([1, 2, 2, 3], dtype=torch.int8), ds.labels - ) + lbl_feat_dict = df.to_torch(return_type="dict", label="purchased") + assert list(lbl_feat_dict.keys()) == ["label", "features"] - def test_to_torch_dataset_feature_subset(self, df: pl.DataFrame) -> None: - ds = df.to_torch("dataset", label="x", features=["z"]) - self.assert_tensor_equal( - torch.tensor([[1.5000], [-0.5000], [0.0000], [-2.0000]]), - ds.features, - ) - self.assert_tensor_equal( - torch.tensor([1, 2, 2, 3], dtype=torch.int8), ds.labels - ) + assert_tensor_equal( + lbl_feat_dict["label"], + torch.tensor([[False], [True], [True], [False], [True]], dtype=torch.bool), + ) + assert_tensor_equal( + lbl_feat_dict["features"], + torch.tensor( + [ + [25, 50000, 1, 0, 0], + [32, 75000, 0, 1, 0], + [45, 60000, 0, 0, 1], + [22, 58000, 1, 0, 0], + [34, 120000, 0, 0, 1], + ], + dtype=torch.int32, + ), + ) - def test_to_torch_dataset_index_slice(self, df: pl.DataFrame) -> None: - ds = df.to_torch("dataset") - ts = ds[1:3] - expected = ( - torch.tensor([[2.0000, 0.0000, -0.5000], [2.0000, 1.0000, 0.0000]]), - ) - self.assert_tensor_equal(expected, ts) +def test_to_torch_dataset(df: pl.DataFrame) -> None: + ds = df.to_torch("dataset", dtype=pl.Float64) - ts = ds[::2] - expected = (torch.tensor([[1.0000, 1.0000, 1.5000], [2.0, 1.0, 0.0]]),) - self.assert_tensor_equal(expected, ts) + assert len(ds) == 4 + assert isinstance(ds, torch.utils.data.Dataset) + assert repr(ds).startswith(" None: + ds = df.to_torch("dataset", label="x", features=["z", "y"]) + assert_tensor_equal( + torch.tensor( + [ + [1.5000, 1.0000], + [-0.5000, 0.0000], + [0.0000, 1.0000], + [-2.0000, 0.0000], + ] + ), + ds.features, ) - def test_to_torch_dataset_index_multi(self, index: Any, df: pl.DataFrame) -> None: - ds = df.to_torch("dataset") - ts = ds[index] + assert_tensor_equal(torch.tensor([1, 2, 2, 3], dtype=torch.int8), ds.labels) - expected = (torch.tensor([[1.0, 1.0, 1.5], [3.0, 0.0, -2.0]]),) - self.assert_tensor_equal(expected, ts) - assert ds.schema == {"features": torch.float32, "labels": None} - def test_to_torch_dataset_index_range(self, df: pl.DataFrame) -> None: - ds = df.to_torch("dataset") - ts = ds[range(3, 0, -1)] +def test_to_torch_dataset_feature_subset(df: pl.DataFrame) -> None: + ds = df.to_torch("dataset", label="x", features=["z"]) + assert_tensor_equal( + torch.tensor([[1.5000], [-0.5000], [0.0000], [-2.0000]]), + ds.features, + ) + assert_tensor_equal(torch.tensor([1, 2, 2, 3], dtype=torch.int8), ds.labels) - expected = ( - torch.tensor([[3.0, 0.0, -2.0], [2.0, 1.0, 0.0], [2.0, 0.0, -0.5]]), - ) - self.assert_tensor_equal(expected, ts) - def test_to_dataset_half_precision(self, df: pl.DataFrame) -> None: - ds = df.to_torch("dataset", label="x") - assert ds.schema == {"features": torch.float32, "labels": torch.int8} +def test_to_torch_dataset_index_slice(df: pl.DataFrame) -> None: + ds = df.to_torch("dataset") + ts = ds[1:3] - dsf16 = ds.half() - assert dsf16.schema == {"features": torch.float16, "labels": torch.float16} + expected = (torch.tensor([[2.0000, 0.0000, -0.5000], [2.0000, 1.0000, 0.0000]]),) + assert_tensor_equal(expected, ts) - # half precision across all data - ts = dsf16[:3:2] - expected = ( - torch.tensor([[1.0000, 1.5000], [1.0000, 0.0000]], dtype=torch.float16), - torch.tensor([1.0, 2.0], dtype=torch.float16), - ) - self.assert_tensor_equal(expected, ts) + ts = ds[::2] + expected = (torch.tensor([[1.0000, 1.0000, 1.5000], [2.0, 1.0, 0.0]]),) + assert_tensor_equal(expected, ts) - # only apply half precision to the feature data - dsf16 = ds.half(labels=False) - assert dsf16.schema == {"features": torch.float16, "labels": torch.int8} - ts = dsf16[:3:2] - expected = ( - torch.tensor([[1.0000, 1.5000], [1.0000, 0.0000]], dtype=torch.float16), - torch.tensor([1, 2], dtype=torch.int8), - ) - self.assert_tensor_equal(expected, ts) +@pytest.mark.parametrize( + "index", + [ + [0, 3], + range(0, 4, 3), + slice(0, 4, 3), + ], +) +def test_to_torch_dataset_index_multi(index: Any, df: pl.DataFrame) -> None: + ds = df.to_torch("dataset") + ts = ds[index] + + expected = (torch.tensor([[1.0, 1.0, 1.5], [3.0, 0.0, -2.0]]),) + assert_tensor_equal(expected, ts) + assert ds.schema == {"features": torch.float32, "labels": None} + + +def test_to_torch_dataset_index_range(df: pl.DataFrame) -> None: + ds = df.to_torch("dataset") + ts = ds[range(3, 0, -1)] + + expected = (torch.tensor([[3.0, 0.0, -2.0], [2.0, 1.0, 0.0], [2.0, 0.0, -0.5]]),) + assert_tensor_equal(expected, ts) + + +def test_to_dataset_half_precision(df: pl.DataFrame) -> None: + ds = df.to_torch("dataset", label="x") + assert ds.schema == {"features": torch.float32, "labels": torch.int8} + + dsf16 = ds.half() + assert dsf16.schema == {"features": torch.float16, "labels": torch.float16} + + # half precision across all data + ts = dsf16[:3:2] + expected = ( + torch.tensor([[1.0000, 1.5000], [1.0000, 0.0000]], dtype=torch.float16), + torch.tensor([1.0, 2.0], dtype=torch.float16), + ) + assert_tensor_equal(expected, ts) + + # only apply half precision to the feature data + dsf16 = ds.half(labels=False) + assert dsf16.schema == {"features": torch.float16, "labels": torch.int8} + + ts = dsf16[:3:2] + expected = ( + torch.tensor([[1.0000, 1.5000], [1.0000, 0.0000]], dtype=torch.float16), + torch.tensor([1, 2], dtype=torch.int8), + ) + assert_tensor_equal(expected, ts) - # only apply half precision to the label data - dsf16 = ds.half(features=False) - assert dsf16.schema == {"features": torch.float32, "labels": torch.float16} + # only apply half precision to the label data + dsf16 = ds.half(features=False) + assert dsf16.schema == {"features": torch.float32, "labels": torch.float16} - ts = dsf16[:3:2] + ts = dsf16[:3:2] + expected = ( + torch.tensor([[1.0000, 1.5000], [1.0000, 0.0000]], dtype=torch.float32), + torch.tensor([1.0, 2.0], dtype=torch.float16), + ) + assert_tensor_equal(expected, ts) + + # no labels + dsf16 = df.to_torch("dataset").half() + assert dsf16.schema == {"features": torch.float16, "labels": None} + + ts = dsf16[:3:2] + expected = ( # type: ignore[assignment] + torch.tensor( + data=[[1.0000, 1.0000, 1.5000], [2.0000, 1.0000, 0.0000]], + dtype=torch.float16, + ), + ) + assert_tensor_equal(expected, ts) + + +@pytest.mark.parametrize( + ("label", "features"), + [ + ("x", None), + ("x", ["y", "z"]), + (cs.by_dtype(pl.INTEGER_DTYPES), ~cs.by_dtype(pl.INTEGER_DTYPES)), + ], +) +def test_to_torch_labelled_dataset(label: Any, features: Any, df: pl.DataFrame) -> None: + ds = df.to_torch("dataset", label=label, features=features) + ts = next(iter(torch.utils.data.DataLoader(ds, batch_size=2, shuffle=False))) + + expected = [ + torch.tensor([[1.0, 1.5], [0.0, -0.5]]), + torch.tensor([1, 2], dtype=torch.int8), + ] + assert len(ts) == len(expected) + for actual, exp in zip(ts, expected): + assert_tensor_equal(exp, actual) + + +def test_to_torch_labelled_dataset_expr(df: pl.DataFrame) -> None: + ds = df.to_torch( + "dataset", + dtype=pl.Float64, + label=(pl.col("x") * 8).cast(pl.Int16), + ) + dl = torch.utils.data.DataLoader(ds, batch_size=2, shuffle=False) + for data in (tuple(ds[:2]), tuple(next(iter(dl)))): expected = ( - torch.tensor([[1.0000, 1.5000], [1.0000, 0.0000]], dtype=torch.float32), - torch.tensor([1.0, 2.0], dtype=torch.float16), - ) - self.assert_tensor_equal(expected, ts) - - # no labels - dsf16 = df.to_torch("dataset").half() - assert dsf16.schema == {"features": torch.float16, "labels": None} - - ts = dsf16[:3:2] - expected = ( # type: ignore[assignment] - torch.tensor( - data=[[1.0000, 1.0000, 1.5000], [2.0000, 1.0000, 0.0000]], - dtype=torch.float16, - ), + torch.tensor([[1.0000, 1.5000], [0.0000, -0.5000]], dtype=torch.float64), + torch.tensor([8, 16], dtype=torch.int16), ) - self.assert_tensor_equal(expected, ts) + assert len(data) == len(expected) + for actual, exp in zip(data, expected): + assert_tensor_equal(exp, actual) - @pytest.mark.parametrize( - ("label", "features"), + +def test_to_torch_labelled_dataset_multi(df: pl.DataFrame) -> None: + ds = df.to_torch("dataset", label=["x", "y"]) + dl = torch.utils.data.DataLoader(ds, batch_size=3, shuffle=False) + ts = list(dl) + + expected = [ [ - ("x", None), - ("x", ["y", "z"]), - (cs.by_dtype(pl.INTEGER_DTYPES), ~cs.by_dtype(pl.INTEGER_DTYPES)), + torch.tensor([[1.5000], [-0.5000], [0.0000]]), + torch.tensor([[1, 1], [2, 0], [2, 1]], dtype=torch.int8), ], - ) - def test_to_torch_labelled_dataset( - self, label: Any, features: Any, df: pl.DataFrame - ) -> None: - ds = df.to_torch("dataset", label=label, features=features) - ts = next(iter(torch.utils.data.DataLoader(ds, batch_size=2, shuffle=False))) - - expected = [ - torch.tensor([[1.0, 1.5], [0.0, -0.5]]), - torch.tensor([1, 2], dtype=torch.int8), - ] - assert len(ts) == len(expected) - for actual, exp in zip(ts, expected): - self.assert_tensor_equal(exp, actual) - - def test_to_torch_labelled_dataset_expr(self, df: pl.DataFrame) -> None: - ds = df.to_torch( - "dataset", - dtype=pl.Float64, - label=(pl.col("x") * 8).cast(pl.Int16), - ) - dl = torch.utils.data.DataLoader(ds, batch_size=2, shuffle=False) - for data in (tuple(ds[:2]), tuple(next(iter(dl)))): - expected = ( - torch.tensor( - [[1.0000, 1.5000], [0.0000, -0.5000]], dtype=torch.float64 - ), - torch.tensor([8, 16], dtype=torch.int16), - ) - assert len(data) == len(expected) - for actual, exp in zip(data, expected): - self.assert_tensor_equal(exp, actual) - - def test_to_torch_labelled_dataset_multi(self, df: pl.DataFrame) -> None: - ds = df.to_torch("dataset", label=["x", "y"]) - dl = torch.utils.data.DataLoader(ds, batch_size=3, shuffle=False) - ts = list(dl) - - expected = [ - [ - torch.tensor([[1.5000], [-0.5000], [0.0000]]), - torch.tensor([[1, 1], [2, 0], [2, 1]], dtype=torch.int8), - ], - [ - torch.tensor([[-2.0]]), - torch.tensor([[3, 0]], dtype=torch.int8), - ], - ] - assert len(ts) == len(expected) - - for actual, exp in zip(ts, expected): - assert len(actual) == len(exp) - for a, e in zip(actual, exp): - self.assert_tensor_equal(e, a) - - def test_misc_errors(self, df: pl.DataFrame) -> None: - ds = df.to_torch("dataset") - - with pytest.raises( - ValueError, - match="invalid `return_type`: 'stroopwafel'", - ): - _res0 = df.to_torch("stroopwafel") # type: ignore[call-overload] - - with pytest.raises( - ValueError, - match="does not support u16, u32, or u64 dtypes", - ): - _res1 = df.to_torch(dtype=pl.UInt16) - - with pytest.raises( - IndexError, - match="tensors used as indices must be long, int", - ): - _res2 = ds[torch.tensor([0, 3], dtype=torch.complex64)] - - with pytest.raises( - ValueError, - match="`label` and `features` only apply when `return_type` is 'dataset' or 'dict'", - ): - _res3 = df.to_torch(label="stroopwafel") - - with pytest.raises( - ValueError, - match="`label` is required if setting `features` when `return_type='dict'", - ): - _res4 = df.to_torch("dict", features=cs.float()) + [ + torch.tensor([[-2.0]]), + torch.tensor([[3, 0]], dtype=torch.int8), + ], + ] + assert len(ts) == len(expected) + + for actual, exp in zip(ts, expected): + assert len(actual) == len(exp) + for a, e in zip(actual, exp): + assert_tensor_equal(e, a) + + +def test_misc_errors(df: pl.DataFrame) -> None: + ds = df.to_torch("dataset") + + with pytest.raises( + ValueError, + match="invalid `return_type`: 'stroopwafel'", + ): + _res0 = df.to_torch("stroopwafel") # type: ignore[call-overload] + + with pytest.raises( + ValueError, + match="does not support u16, u32, or u64 dtypes", + ): + _res1 = df.to_torch(dtype=pl.UInt16) + + with pytest.raises( + IndexError, + match="tensors used as indices must be long, int", + ): + _res2 = ds[torch.tensor([0, 3], dtype=torch.complex64)] + + with pytest.raises( + ValueError, + match="`label` and `features` only apply when `return_type` is 'dataset' or 'dict'", + ): + _res3 = df.to_torch(label="stroopwafel") + + with pytest.raises( + ValueError, + match="`label` is required if setting `features` when `return_type='dict'", + ): + _res4 = df.to_torch("dict", features=cs.float()) From 6eef9b9f2a8c27ccac61cd06d35a7656e2414210 Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Sun, 19 May 2024 19:02:09 +0400 Subject: [PATCH 08/11] mark as `unstable` --- py-polars/polars/dataframe/frame.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 257d806b258c..66e5f072fc46 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -1635,6 +1635,7 @@ def to_jax( order: IndexOrder = ..., ) -> dict[str, jax.Array]: ... + @unstable() def to_jax( self, return_type: JaxExportType = "array", @@ -1650,6 +1651,10 @@ def to_jax( .. versionadded:: 0.20.27 + .. warning:: + This functionality is currently considered **unstable**. It may be + changed at any point without it being considered a breaking change. + Parameters ---------- return_type : {"array", "dict"} @@ -1838,6 +1843,7 @@ def to_torch( dtype: PolarsDataType | None = ..., ) -> dict[str, torch.Tensor]: ... + @unstable() def to_torch( self, return_type: TorchExportType = "tensor", @@ -1851,6 +1857,10 @@ def to_torch( .. versionadded:: 0.20.23 + .. warning:: + This functionality is currently considered **unstable**. It may be + changed at any point without it being considered a breaking change. + Parameters ---------- return_type : {"tensor", "dataset", "dict"} From 1e417600972a284ae7ff421f302efe31fee7a389 Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Sun, 19 May 2024 20:59:01 +0400 Subject: [PATCH 09/11] add `to_jax` entry to the sphinx docs --- py-polars/docs/source/reference/dataframe/export.rst | 1 + py-polars/polars/dataframe/frame.py | 8 ++++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/py-polars/docs/source/reference/dataframe/export.rst b/py-polars/docs/source/reference/dataframe/export.rst index 12cb378dc6ef..0347b7429da0 100644 --- a/py-polars/docs/source/reference/dataframe/export.rst +++ b/py-polars/docs/source/reference/dataframe/export.rst @@ -13,6 +13,7 @@ Export DataFrame data to other formats: DataFrame.to_dict DataFrame.to_dicts DataFrame.to_init_repr + DataFrame.to_jax DataFrame.to_numpy DataFrame.to_pandas DataFrame.to_struct diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 66e5f072fc46..085ed0d259e0 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -1647,7 +1647,7 @@ def to_jax( order: IndexOrder = "fortran", ) -> jax.Array | dict[str, jax.Array]: """ - Convert DataFrame to a 2D Jax Array, or dict of Jax Arrays. + Convert DataFrame to a Jax Array, or dict of Jax Arrays. .. versionadded:: 0.20.27 @@ -1658,7 +1658,7 @@ def to_jax( Parameters ---------- return_type : {"array", "dict"} - Set return type; a 2D Jax Array, or dict of Jax Arrays. + Set return type; a Jax Array, or dict of Jax Arrays. device Specify the jax `Device` on which the array will be created; can provide a string (such as "cpu", "gpu", or "tpu") in which case the device is @@ -1853,7 +1853,7 @@ def to_torch( dtype: PolarsDataType | None = None, ) -> torch.Tensor | dict[str, torch.Tensor] | PolarsDataset: """ - Convert DataFrame to a 2D PyTorch Tensor, Dataset, or dict of Tensors. + Convert DataFrame to a PyTorch Tensor, Dataset, or dict of Tensors. .. versionadded:: 0.20.23 @@ -1864,7 +1864,7 @@ def to_torch( Parameters ---------- return_type : {"tensor", "dataset", "dict"} - Set return type; a 2D PyTorch Tensor, PolarsDataset (a frame-specialized + Set return type; a PyTorch Tensor, PolarsDataset (a frame-specialized TensorDataset), or dict of Tensors. label One or more column names, expressions, or selectors that label the feature From a68782a1b5fbce386db3617c6b1654473df642aa Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Sun, 19 May 2024 22:40:30 +0400 Subject: [PATCH 10/11] add `series` sphinx docs entry and mark `unstable` --- .../src/chunked_array/ops/aggregate/mod.rs | 2 +- py-polars/docs/source/reference/series/export.rst | 2 ++ py-polars/polars/series/series.py | 14 ++++++++++++++ 3 files changed, 17 insertions(+), 1 deletion(-) diff --git a/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs b/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs index 7fe43517b365..5e4b09e5d18a 100644 --- a/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs @@ -195,7 +195,7 @@ where } } -/// Booleans are casted to 1 or 0. +/// Booleans are cast to 1 or 0. impl BooleanChunked { pub fn sum(&self) -> Option { Some(if self.is_empty() { diff --git a/py-polars/docs/source/reference/series/export.rst b/py-polars/docs/source/reference/series/export.rst index 6e19c4efa4f7..c1c7bacf8086 100644 --- a/py-polars/docs/source/reference/series/export.rst +++ b/py-polars/docs/source/reference/series/export.rst @@ -10,7 +10,9 @@ Export Series data to other formats: Series.to_arrow Series.to_frame + Series.to_jax Series.to_list Series.to_numpy Series.to_pandas Series.to_init_repr + Series.to_torch diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index ec2f09d38c8d..aa2e9f088098 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -4476,10 +4476,17 @@ def to_numpy( return self._s.to_numpy(allow_copy=allow_copy, writable=writable) + @unstable() def to_jax(self, device: jax.Device | str | None = None) -> jax.Array: """ Convert this Series to a Jax Array. + .. versionadded:: 0.20.27 + + .. warning:: + This functionality is currently considered **unstable**. It may be + changed at any point without it being considered a breaking change. + Parameters ---------- device @@ -4519,10 +4526,17 @@ def to_jax(self, device: jax.Device | str | None = None) -> jax.Array: order="K", ) + @unstable() def to_torch(self) -> torch.Tensor: """ Convert this Series to a PyTorch Tensor. + .. versionadded:: 0.20.23 + + .. warning:: + This functionality is currently considered **unstable**. It may be + changed at any point without it being considered a breaking change. + Examples -------- >>> s = pl.Series("x", [1, 0, 1, 2, 0], dtype=pl.UInt8) From 362b8d9be80e2a5cce9f8b3c0b3069980ffae15c Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Mon, 20 May 2024 00:25:33 +0400 Subject: [PATCH 11/11] add a docstring note to series `to_torch` --- py-polars/polars/series/series.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index aa2e9f088098..66f150957584 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -4537,11 +4537,19 @@ def to_torch(self) -> torch.Tensor: This functionality is currently considered **unstable**. It may be changed at any point without it being considered a breaking change. + Notes + ----- + PyTorch tensors do not support UInt16, UInt32, or UInt64; these dtypes + will be automatically cast to Int32, Int64, and Int64, respectively. + Examples -------- >>> s = pl.Series("x", [1, 0, 1, 2, 0], dtype=pl.UInt8) >>> s.to_torch() tensor([1, 0, 1, 2, 0], dtype=torch.uint8) + >>> s = pl.Series("x", [5.5, -10.0, 2.5], dtype=pl.Float32) + >>> s.to_torch() + tensor([ 5.5000, -10.0000, 2.5000]) """ torch = import_optional("torch")