Skip to content

Commit

Permalink
feat(python): Various Schema improvements (new base_types method,…
Browse files Browse the repository at this point in the history
… improved equality/init dtype checks)
  • Loading branch information
alexander-beedie committed Oct 22, 2024
1 parent 27289b2 commit da188e1
Show file tree
Hide file tree
Showing 13 changed files with 127 additions and 61 deletions.
2 changes: 2 additions & 0 deletions py-polars/polars/_reexport.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
from polars.dataframe import DataFrame
from polars.expr import Expr, When
from polars.lazyframe import LazyFrame
from polars.schema import Schema
from polars.series import Series

__all__ = [
"DataFrame",
"Expr",
"LazyFrame",
"Schema",
"Series",
"When",
]
8 changes: 6 additions & 2 deletions py-polars/polars/datatypes/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,14 @@


def is_polars_dtype(
dtype: Any, *, include_unknown: bool = False
dtype: Any,
*,
include_unknown: bool = False,
require_instantiated: bool = False,
) -> TypeGuard[PolarsDataType]:
"""Indicate whether the given input is a Polars dtype, or dtype specialization."""
is_dtype = isinstance(dtype, (DataType, DataTypeClass))
check_classes = DataType if require_instantiated else (DataType, DataTypeClass)
is_dtype = isinstance(dtype, check_classes) # type: ignore[arg-type]

if not include_unknown:
return is_dtype and dtype != Unknown
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/io/spreadsheet/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ def _read_spreadsheet(
infer_schema_length=infer_schema_length,
)
engine_options = (engine_options or {}).copy()
schema_overrides = dict(schema_overrides or {})
schema_overrides = pl.Schema(schema_overrides or {})

# establish the reading function, parser, and available worksheets
reader_fn, parser, worksheets = _initialise_spreadsheet_parser(
Expand Down
60 changes: 53 additions & 7 deletions py-polars/polars/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,28 @@
from collections.abc import Mapping
from typing import TYPE_CHECKING

from polars.datatypes import DataType
from polars.datatypes import DataType, is_polars_dtype
from polars.datatypes._parse import parse_into_dtype

BaseSchema = OrderedDict[str, DataType]

if TYPE_CHECKING:
from collections.abc import Iterable

from polars._typing import PythonDataType
from polars.datatypes import DataTypeClass


BaseSchema = OrderedDict[str, DataType]

__all__ = ["Schema"]


def _check_nested_dtype(tp: DataType | DataTypeClass) -> bool:
    """
    Validate that a nested dtype is fully-specified.

    Nested dtypes (e.g. List/Struct) carry inner types, so the bare class form
    is ambiguous; raise `TypeError` for a nested dtype that is not an
    instantiated `DataType`. Returns True otherwise (so the call can be used
    inline in a boolean expression).
    """
    if not isinstance(tp, DataType) and tp.is_nested():
        raise TypeError(f"nested dtypes must be fully-specified, got: {tp!r}")
    return True


class Schema(BaseSchema):
"""
Ordered mapping of column names to their data type.
Expand Down Expand Up @@ -62,11 +70,49 @@ def __init__(
input = (
schema.items() if schema and isinstance(schema, Mapping) else (schema or {})
)
super().__init__({name: parse_into_dtype(tp) for name, tp in input}) # type: ignore[misc]

def __setitem__(self, name: str, dtype: DataType | PythonDataType) -> None:
for name, tp in input: # type: ignore[misc]
if is_polars_dtype(tp) and _check_nested_dtype(tp):
super().__setitem__(name, tp) # type: ignore[assignment]
else:
self[name] = tp

def __eq__(self, other: object) -> bool:
    """
    Strict, order-sensitive schema equality.

    Two schemas are equal when the other object is a Mapping of the same
    length and every (name, dtype) pair matches positionally, with dtypes
    compared via strict identity (`is_`) rather than loose equality.
    """
    if not isinstance(other, Mapping) or len(self) != len(other):
        return False
    return all(
        left_name == right_name and left_tp.is_(right_tp)
        for (left_name, left_tp), (right_name, right_tp) in zip(
            self.items(), other.items()
        )
    )

def __ne__(self, other: object) -> bool:
    """Inverse of `__eq__` (delegates via the equality operator)."""
    return not (self == other)

def __setitem__(
    self, name: str, dtype: DataType | DataTypeClass | PythonDataType
) -> None:
    """
    Assign a dtype to a column, coercing Python types to Polars dtypes.

    Parses the incoming dtype once (the walrus rebinds `dtype` to the parsed
    `DataType`), validates that nested dtypes are fully-specified (raises
    TypeError otherwise), then stores the parsed value.
    """
    _check_nested_dtype(dtype := parse_into_dtype(dtype))
    # `dtype` is already parsed above; re-parsing it here would be redundant.
    super().__setitem__(name, dtype)  # type: ignore[assignment]

def base_types(self) -> dict[str, DataTypeClass]:
    """
    Return a dictionary of column names and the fundamental/root type class.

    Examples
    --------
    >>> s = pl.Schema(
    ...     {
    ...         "x": pl.Float64(),
    ...         "y": pl.List(pl.Int32),
    ...         "z": pl.Struct([pl.Field("a", pl.Int8), pl.Field("b", pl.Boolean)]),
    ...     }
    ... )
    >>> s.base_types()
    {'x': Float64, 'y': List, 'z': Struct}
    """
    return {column: dtype.base_type() for column, dtype in self.items()}

def names(self) -> list[str]:
    """Get the column names of the schema."""
    return [*self.keys()]
Expand All @@ -81,7 +127,7 @@ def len(self) -> int:

def to_python(self) -> dict[str, type]:
"""
Return Schema as a dictionary of column names and their Python types.
Return a dictionary of column names and Python types.
Examples
--------
Expand Down
5 changes: 2 additions & 3 deletions py-polars/tests/unit/constructors/test_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def test_init_dict() -> None:
data={"dt": dates, "dtm": datetimes},
schema=coldefs,
)
assert df.schema == {"dt": pl.Date, "dtm": pl.Datetime}
assert df.schema == {"dt": pl.Date, "dtm": pl.Datetime("us")}
assert df.rows() == list(zip(py_dates, py_datetimes))

# Overriding dict column names/types
Expand Down Expand Up @@ -251,7 +251,7 @@ class TradeNT(NamedTuple):
)
assert df.schema == {
"ts": pl.Datetime("ms"),
"tk": pl.Categorical,
"tk": pl.Categorical(ordering="physical"),
"pc": pl.Decimal(scale=1),
"sz": pl.UInt16,
}
Expand Down Expand Up @@ -284,7 +284,6 @@ class PageView(BaseModel):
models = adapter.validate_json(data_json)

result = pl.DataFrame(models)

expected = pl.DataFrame(
{
"user_id": ["x"],
Expand Down
4 changes: 2 additions & 2 deletions py-polars/tests/unit/datatypes/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_dtype() -> None:
"u": pl.List(pl.UInt64),
"tm": pl.List(pl.Time),
"dt": pl.List(pl.Date),
"dtm": pl.List(pl.Datetime),
"dtm": pl.List(pl.Datetime("us")),
}
assert all(tp.is_nested() for tp in df.dtypes)
assert df.schema["i"].inner == pl.Int8 # type: ignore[attr-defined]
Expand Down Expand Up @@ -160,7 +160,7 @@ def test_empty_list_construction() -> None:
assert df.to_dict(as_series=False) == expected

df = pl.DataFrame(schema=[("col", pl.List)])
assert df.schema == {"col": pl.List}
assert df.schema == {"col": pl.List(pl.Null)}
assert df.rows() == []


Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/datatypes/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def build_struct_df(data: list[dict[str, object]]) -> pl.DataFrame:
# struct column
df = build_struct_df([{"struct_col": {"inner": 1}}])
assert df.columns == ["struct_col"]
assert df.schema == {"struct_col": pl.Struct}
assert df.schema == {"struct_col": pl.Struct({"inner": pl.Int64})}
assert df["struct_col"].struct.field("inner").to_list() == [1]

# struct in struct
Expand Down
23 changes: 4 additions & 19 deletions py-polars/tests/unit/datatypes/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,16 +632,7 @@ def test_asof_join() -> None:
"2016-05-25 13:30:00.072",
"2016-05-25 13:30:00.075",
]
ticker = [
"GOOG",
"MSFT",
"MSFT",
"MSFT",
"GOOG",
"AAPL",
"GOOG",
"MSFT",
]
ticker = ["GOOG", "MSFT", "MSFT", "MSFT", "GOOG", "AAPL", "GOOG", "MSFT"]
quotes = pl.DataFrame(
{
"dates": pl.Series(dates).str.strptime(pl.Datetime, format=format),
Expand All @@ -656,13 +647,7 @@ def test_asof_join() -> None:
"2016-05-25 13:30:00.048",
"2016-05-25 13:30:00.048",
]
ticker = [
"MSFT",
"MSFT",
"GOOG",
"GOOG",
"AAPL",
]
ticker = ["MSFT", "MSFT", "GOOG", "GOOG", "AAPL"]
trades = pl.DataFrame(
{
"dates": pl.Series(dates).str.strptime(pl.Datetime, format=format),
Expand All @@ -678,11 +663,11 @@ def test_asof_join() -> None:
out = trades.join_asof(quotes, on="dates", strategy="backward")

assert out.schema == {
"bid": pl.Float64,
"bid_right": pl.Float64,
"dates": pl.Datetime("ms"),
"ticker": pl.String,
"bid": pl.Float64,
"ticker_right": pl.String,
"bid_right": pl.Float64,
}
assert out.columns == ["dates", "ticker", "bid", "ticker_right", "bid_right"]
assert (out["dates"].cast(int)).to_list() == [
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/interop/test_from_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_from_pandas() -> None:
"floats_nulls": pl.Float64,
"strings": pl.String,
"strings_nulls": pl.String,
"strings-cat": pl.Categorical,
"strings-cat": pl.Categorical(ordering="physical"),
}
assert out.rows() == [
(False, None, 1, 1.0, 1.0, 1.0, "foo", "foo", "foo"),
Expand Down
4 changes: 2 additions & 2 deletions py-polars/tests/unit/interop/test_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def test_from_dict_struct() -> None:
assert df.shape == (2, 2)
assert df["a"][0] == {"b": 1, "c": 2}
assert df["a"][1] == {"b": 3, "c": 4}
assert df.schema == {"a": pl.Struct, "d": pl.Int64}
assert df.schema == {"a": pl.Struct({"b": pl.Int64, "c": pl.Int64}), "d": pl.Int64}


def test_from_dicts() -> None:
Expand Down Expand Up @@ -397,7 +397,7 @@ def test_dataframe_from_repr() -> None:
assert frame.schema == {
"a": pl.Int64,
"b": pl.Float64,
"c": pl.Categorical,
"c": pl.Categorical(ordering="physical"),
"d": pl.Boolean,
"e": pl.String,
"f": pl.Date,
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/operations/test_group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def test_group_by_sorted_empty_dataframe_3680() -> None:
)
assert df.rows() == []
assert df.shape == (0, 2)
assert df.schema == {"key": pl.Categorical, "val": pl.Float64}
assert df.schema == {"key": pl.Categorical(ordering="physical"), "val": pl.Float64}


def test_group_by_custom_agg_empty_list() -> None:
Expand Down
66 changes: 48 additions & 18 deletions py-polars/tests/unit/test_schema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import pickle
from datetime import datetime

import pytest

import polars as pl


Expand All @@ -13,24 +15,28 @@ def test_schema() -> None:
assert s.names() == ["foo", "bar"]
assert s.dtypes() == [pl.Int8(), pl.String()]

with pytest.raises(
TypeError,
match="nested dtypes must be fully-specified, got: List",
):
pl.Schema({"foo": pl.String, "bar": pl.List})

def test_schema_parse_nonpolars_dtypes() -> None:
cardinal_directions = pl.Enum(["north", "south", "east", "west"])

s = pl.Schema({"foo": pl.List, "bar": int, "baz": cardinal_directions}) # type: ignore[arg-type]
s["ham"] = datetime

assert s["foo"] == pl.List
assert s["bar"] == pl.Int64
assert s["baz"] == cardinal_directions
assert s["ham"] == pl.Datetime("us")

assert s.len() == 4
assert s.names() == ["foo", "bar", "baz", "ham"]
assert s.dtypes() == [pl.List, pl.Int64, cardinal_directions, pl.Datetime("us")]

assert list(s.to_python().values()) == [list, int, str, datetime]
assert [tp.to_python() for tp in s.dtypes()] == [list, int, str, datetime]
def test_schema_base_types() -> None:
    # `base_types` should strip any parametrization (time units, inner
    # types, struct fields) and report only the root dtype class.
    schema = pl.Schema(
        {
            "a": pl.Int8(),
            "b": pl.Datetime("us"),
            "c": pl.Array(pl.Int8(), shape=(4,)),
            "d": pl.Struct({"time": pl.List(pl.Duration), "dist": pl.Float64}),
        }
    )
    expected = {
        "a": pl.Int8,
        "b": pl.Datetime,
        "c": pl.Array,
        "d": pl.Struct,
    }
    assert schema.base_types() == expected


def test_schema_equality() -> None:
Expand All @@ -45,6 +51,32 @@ def test_schema_equality() -> None:
assert s1 != s3
assert s2 != s3

s4 = pl.Schema({"foo": pl.Datetime("us"), "bar": pl.Duration("ns")})
s5 = pl.Schema({"foo": pl.Datetime("ns"), "bar": pl.Duration("us")})
s6 = {"foo": pl.Datetime, "bar": pl.Duration}

assert s4 != s5
assert s4 != s6


def test_schema_parse_python_dtypes() -> None:
    # Python types (int, datetime) given as schema values should be coerced
    # to the equivalent Polars dtypes on both init and item assignment.
    compass = pl.Enum(["north", "south", "east", "west"])

    schema = pl.Schema({"foo": pl.List(pl.Int32), "bar": int, "baz": compass})  # type: ignore[arg-type]
    schema["ham"] = datetime

    expected = {
        "foo": pl.List(pl.Int32),
        "bar": pl.Int64,
        "baz": compass,
        "ham": pl.Datetime("us"),
    }
    for column, dtype in expected.items():
        assert schema[column] == dtype

    assert schema.len() == 4
    assert schema.names() == ["foo", "bar", "baz", "ham"]
    assert schema.dtypes() == [pl.List, pl.Int64, compass, pl.Datetime("us")]

    assert list(schema.to_python().values()) == [list, int, str, datetime]
    assert [tp.to_python() for tp in schema.dtypes()] == [list, int, str, datetime]


def test_schema_picklable() -> None:
s = pl.Schema(
Expand Down Expand Up @@ -88,7 +120,6 @@ def test_schema_in_map_elements_returns_scalar() -> None:
"amounts": [100.0, -110.0] * 2,
}
)

q = ldf.group_by("portfolio").agg(
pl.col("amounts")
.map_elements(
Expand All @@ -112,7 +143,6 @@ def test_schema_functions_in_agg_with_literal_arg_19011() -> None:
.rolling(index_column=pl.int_range(pl.len()).alias("idx"), period="3i")
.agg(pl.col("a").fill_null(0).alias("a_1"), pl.col("a").pow(2.0).alias("a_2"))
)

assert q.collect_schema() == pl.Schema(
[("idx", pl.Int64), ("a_1", pl.List(pl.Int64)), ("a_2", pl.List(pl.Float64))]
)
Loading

0 comments on commit da188e1

Please sign in to comment.