Skip to content

Commit

Permalink
fix(python): Fix parsing of shape in Array constructor and deprec…
Browse files Browse the repository at this point in the history
…ate `width` parameter (#16567)
  • Loading branch information
stinodego authored May 30, 2024
1 parent aa771aa commit f1be8d9
Show file tree
Hide file tree
Showing 11 changed files with 91 additions and 42 deletions.
59 changes: 41 additions & 18 deletions py-polars/polars/datatypes/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,30 +736,45 @@ class Array(NestedType):

inner: PolarsDataType | None = None
size: int
# outer shape
shape: None | tuple[int, ...] = None
shape: tuple[int, ...]

def __init__(
self,
inner: PolarsDataType | PythonDataType,
shape: int | tuple[int, ...] | None = None,
*,
width: int | None = None,
):
if width is not None:
from polars._utils.deprecation import issue_deprecation_warning

issue_deprecation_warning(
"The `width` parameter for `Array` is deprecated. Use `shape` instead.",
version="0.20.31",
)
shape = width
elif shape is None:
msg = "Array constructor is missing the required argument `shape`"
raise TypeError(msg)

inner_parsed = polars.datatypes.py_type_to_dtype(inner)

if isinstance(shape, int):
self.inner = inner_parsed
self.size = shape
self.shape = (shape,)

elif isinstance(shape, tuple):
if len(shape) > 1:
self.shape = shape
for dim in shape[1:]:
inner = Array(inner, dim)
shape = shape[0]
inner_parsed = Array(inner_parsed, shape[1:])

if shape is None:
msg = "either 'shape' or 'width' must be set"
raise ValueError(msg)
self.inner = inner_parsed
self.size = shape[0]
self.shape = shape

self.inner = polars.datatypes.py_type_to_dtype(inner)
self.size = shape
else:
msg = f"invalid input for shape: {shape!r}"
raise TypeError(msg)

def __eq__(self, other: PolarsDataType) -> bool: # type: ignore[override]
# This equality check allows comparison of type classes and type instances.
Expand All @@ -785,16 +800,24 @@ def __hash__(self) -> int:
return hash((self.__class__, self.inner, self.size))

def __repr__(self) -> str:
# Get leaf type
dtype = self.inner
while isinstance(dtype, Array):
dtype = dtype.inner

class_name = self.__class__.__name__
return f"{class_name}({dtype!r}, shape={self.shape})"

if self.shape:
# get leaf type
dtype = self.inner
while isinstance(dtype, Array):
dtype = dtype.inner
@property
def width(self) -> int:
"""The size of the Array."""
from polars._utils.deprecation import issue_deprecation_warning

return f"{class_name}({dtype!r}, shape={self.shape})"
return f"{class_name}({self.inner!r}, size={self.size})"
issue_deprecation_warning(
"The `width` attribute for `Array` is deprecated. Use `size` instead.",
version="0.20.31",
)
return self.size


class Field:
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/constructors/test_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1624,7 +1624,7 @@ def test_array_construction() -> None:
{"row_id": "a", "data": [1, 2, 3]},
{"row_id": "b", "data": [2, 3, 4]},
]
schema = {"row_id": pl.String(), "data": pl.Array(inner=pl.Int64, width=3)}
schema = {"row_id": pl.String(), "data": pl.Array(inner=pl.Int64, shape=3)}
df = pl.from_dicts(rows, schema=schema)
assert df.schema == schema
assert df.rows() == [("a", [1, 2, 3]), ("b", [2, 3, 4])]
16 changes: 8 additions & 8 deletions py-polars/tests/unit/dataframe/test_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,19 +76,19 @@ def test_df_serde_enum() -> None:
@pytest.mark.parametrize(
("data", "dtype"),
[
([[1, 2, 3], [None, None, None], [1, None, 3]], pl.Array(pl.Int32(), width=3)),
([["a", "b"], [None, None]], pl.Array(pl.Utf8, width=2)),
([[True, False, None], [None, None, None]], pl.Array(pl.Utf8, width=3)),
([[1, 2, 3], [None, None, None], [1, None, 3]], pl.Array(pl.Int32(), shape=3)),
([["a", "b"], [None, None]], pl.Array(pl.Utf8, shape=2)),
([[True, False, None], [None, None, None]], pl.Array(pl.Utf8, shape=3)),
(
[[[1, 2, 3], [4, None, 5]], None, [[None, None, 2]]],
pl.List(pl.Array(pl.Int32(), width=3)),
pl.List(pl.Array(pl.Int32(), shape=3)),
),
(
[
[datetime(1991, 1, 1), datetime(1991, 1, 1), None],
[None, None, None],
],
pl.Array(pl.Datetime, width=3),
pl.Array(pl.Datetime, shape=3),
),
],
)
Expand All @@ -112,18 +112,18 @@ def test_write_read_json_array(data: Any, dtype: pl.DataType) -> None:
],
[None, None],
],
pl.Array(pl.Datetime, width=2),
pl.Array(pl.Datetime, shape=2),
),
(
[[date(1997, 10, 1), date(2000, 1, 1)], [None, None]],
pl.Array(pl.Date, width=2),
pl.Array(pl.Date, shape=2),
),
(
[
[timedelta(seconds=1), timedelta(seconds=10)],
[None, None],
],
pl.Array(pl.Duration, width=2),
pl.Array(pl.Duration, shape=2),
),
],
)
Expand Down
31 changes: 28 additions & 3 deletions py-polars/tests/unit/datatypes/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def test_array_data_type_equality() -> None:
def test_cast_list_to_array(data: Any, inner_type: pl.DataType) -> None:
s = pl.Series(data, dtype=pl.List(inner_type))
s = s.cast(pl.Array(inner_type, 2))
assert s.dtype == pl.Array(inner_type, width=2)
assert s.dtype == pl.Array(inner_type, shape=2)
assert s.to_list() == data


Expand Down Expand Up @@ -259,7 +259,7 @@ def test_arr_median(data_dispersion: pl.DataFrame) -> None:


def test_array_repeat() -> None:
dtype = pl.Array(pl.UInt8, width=1)
dtype = pl.Array(pl.UInt8, shape=1)
s = pl.repeat([42], n=3, dtype=dtype, eager=True)
expected = pl.Series("repeat", [[42], [42], [42]], dtype=dtype)
assert s.dtype == dtype
Expand Down Expand Up @@ -291,7 +291,7 @@ def test_array_ndarray_reshape() -> None:

def test_recursive_array_dtype() -> None:
assert str(pl.Array(pl.Int64, (2, 3))) == "Array(Int64, shape=(2, 3))"
assert str(pl.Array(pl.Int64, 3)) == "Array(Int64, size=3)"
assert str(pl.Array(pl.Int64, 3)) == "Array(Int64, shape=(3,))"
dtype = pl.Array(pl.Int64, 3)
s = pl.Series(np.arange(6).reshape((2, 3)), dtype=dtype)
assert s.dtype == dtype
Expand All @@ -303,3 +303,28 @@ def test_ndarray_construction() -> None:
s = pl.Series(a)
assert s.dtype == pl.Array(pl.Int64, (4, 2))
assert (s.to_numpy() == a).all()


def test_array_width_deprecated() -> None:
with pytest.deprecated_call():
dtype = pl.Array(pl.Int8, width=2)
with pytest.deprecated_call():
assert dtype.width == 2


def test_array_inner_recursive() -> None:
shape = (2, 3, 4, 5)
dtype = pl.Array(int, shape=shape)
for dim in shape:
assert dtype.size == dim
dtype = dtype.inner # type: ignore[assignment]


def test_array_inner_recursive_python_dtype() -> None:
dtype = pl.Array(int, shape=(2, 3))
assert dtype.inner.inner == pl.Int64 # type: ignore[union-attr]


def test_array_missing_shape() -> None:
with pytest.raises(TypeError):
pl.Array(pl.Int8)
8 changes: 4 additions & 4 deletions py-polars/tests/unit/functions/test_repeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ def test_repeat_n_negative() -> None:
(2, ["1"], pl.List(pl.Utf8)),
(4, True, pl.Boolean),
(2, [True], pl.List(pl.Boolean)),
(2, [1], pl.Array(pl.Int16, width=1)),
(2, [1, 1, 1], pl.Array(pl.Int8, width=3)),
(2, [1], pl.Array(pl.Int16, shape=1)),
(2, [1, 1, 1], pl.Array(pl.Int8, shape=3)),
(1, [1], pl.List(pl.UInt32)),
],
)
Expand Down Expand Up @@ -126,8 +126,8 @@ def test_ones(
(2, ["0"], pl.List(pl.Utf8)),
(4, False, pl.Boolean),
(2, [False], pl.List(pl.Boolean)),
(3, [0], pl.Array(pl.UInt32, width=1)),
(2, [0, 0, 0], pl.Array(pl.UInt32, width=3)),
(3, [0], pl.Array(pl.UInt32, shape=1)),
(2, [0, 0, 0], pl.Array(pl.UInt32, shape=3)),
(1, [0], pl.List(pl.UInt32)),
],
)
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/io/test_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,7 +804,7 @@ def test_sliced_dict_with_nulls_14904() -> None:

def test_parquet_array_dtype() -> None:
df = pl.DataFrame({"x": [[1, 2, 3]]})
df = df.cast({"x": pl.Array(pl.Int64, width=3)})
df = df.cast({"x": pl.Array(pl.Int64, shape=3)})
test_round_trip(df)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_arr_to_list() -> None:
df = pl.DataFrame(
data,
schema={
"duration": pl.Array(pl.Datetime, width=2),
"duration": pl.Array(pl.Datetime, shape=2),
},
).with_columns(pl.col("duration").arr.to_list())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def test_contains() -> None:


def test_list_contains_invalid_datatype() -> None:
df = pl.DataFrame({"a": [[1, 2], [3, 4]]}, schema={"a": pl.Array(pl.Int8, width=2)})
df = pl.DataFrame({"a": [[1, 2], [3, 4]]}, schema={"a": pl.Array(pl.Int8, shape=2)})
with pytest.raises(pl.SchemaError, match="invalid series dtype: expected `List`"):
df.select(pl.col("a").list.contains(2))

Expand Down
3 changes: 2 additions & 1 deletion py-polars/tests/unit/series/test_equals.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ def test_eq_array_cmp_list() -> None:
def test_eq_array_cmp_int() -> None:
s = pl.Series([[1, 3], [1, 2]], dtype=pl.Array(pl.Int16, 2))
with pytest.raises(
TypeError, match="cannot convert Python type 'int' to Array\\(Int16, size=2\\)"
TypeError,
match="cannot convert Python type 'int' to Array\\(Int16, shape=\\(2,\\)\\)",
):
s == 1 # noqa: B015

Expand Down
4 changes: 2 additions & 2 deletions py-polars/tests/unit/test_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,13 +183,13 @@ def test_serde_keep_dtype_empty_list() -> None:
def test_serde_array_dtype() -> None:
s = pl.Series(
[[1, 2, 3], [None, None, None], [1, None, 3]],
dtype=pl.Array(pl.Int32(), width=3),
dtype=pl.Array(pl.Int32(), 3),
)
assert_series_equal(pickle.loads(pickle.dumps(s)), s)

nested_s = pl.Series(
[[[1, 2, 3], [4, None, 5]], None, [[None, None, 2]]],
dtype=pl.List(pl.Array(pl.Int32(), width=3)),
dtype=pl.List(pl.Array(pl.Int32(), 3)),
)
assert_series_equal(pickle.loads(pickle.dumps(nested_s)), nested_s)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def test_allow_infinities_deprecated(data: st.DataObject) -> None:
@given(
df=dataframes(
cols=[
column("colx", dtype=pl.Array(pl.UInt8, width=3)),
column("colx", dtype=pl.Array(pl.UInt8, shape=3)),
column("coly", dtype=pl.List(pl.Datetime("ms"))),
column(
name="colz",
Expand All @@ -208,7 +208,7 @@ def test_allow_infinities_deprecated(data: st.DataObject) -> None:
)
def test_dataframes_nested_strategies(df: pl.DataFrame) -> None:
assert df.schema == {
"colx": pl.Array(pl.UInt8, width=3),
"colx": pl.Array(pl.UInt8, shape=3),
"coly": pl.List(pl.Datetime("ms")),
"colz": pl.List(pl.List(pl.String)),
}
Expand Down

0 comments on commit f1be8d9

Please sign in to comment.