Skip to content

Commit

Permalink
Preserve DataFrame type after LazyFrame roundtrips (#2862)
Browse files Browse the repository at this point in the history
  • Loading branch information
JakobGM authored Mar 10, 2022
1 parent fe28511 commit 0038bd2
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 40 deletions.
3 changes: 1 addition & 2 deletions py-polars/polars/internals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
via this __init__ file using `import polars.internals as pli`. The imports below are being shared across this module.
"""
from .expr import Expr, expr_to_lit_or_expr, selection_to_pyexpr_list, wrap_expr
from .frame import DataFrame, wrap_df
from .frame import DataFrame, LazyFrame, wrap_df, wrap_ldf
from .functions import concat, date_range # DataFrame.describe() & DataFrame.upsample()
from .lazy_frame import LazyFrame, wrap_ldf
from .lazy_functions import all, argsort_by, col, concat_list, lit, select
from .series import Series, wrap_s
from .whenthen import when # used in expr.clip()
115 changes: 85 additions & 30 deletions py-polars/polars/internals/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@
series_to_pydf,
)

from .lazy_frame import LazyFrame, wrap_ldf # noqa: F401

try:
from polars.polars import PyDataFrame, PySeries

Expand Down Expand Up @@ -98,7 +100,75 @@ def _prepare_other_arg(other: Any) -> "pli.Series":
return other


class DataFrame:
class DataFrameMetaClass(type):
"""
Custom metaclass for DataFrame class.
This metaclass is responsible for constructing the relationship between the
DataFrame class and the LazyFrame class. Originally, without inheritance, the
relationship is as follows:
DataFrame <-> LazyFrame
This two-way relationship is represented by the following pointers:
- cls._lazyframe_class: A pointer on the DataFrame (sub)class to a LazyFrame
(sub)class. This class property can be used in DataFrame methods in order
to construct new lazy dataframes.
- cls._lazyframe_class._dataframe_class: A pointer on the LazyFrame (sub)class
back to the original DataFrame (sub)class. This allows LazyFrame methods to
construct new non-lazy dataframes with the correct type. This pointer should
always be set to cls such that the following is always `True`:
`type(cls) is type(cls.lazy().collect())`.
If an end user subclasses DataFrame like so:
>>> class MyDataFrame(pl.DataFrame):
... pass
...
Then the following class is dynamically created by the metaclass and saved on the
class variable `MyDataFrame._lazyframe_class`.
>>> class LazyMyDataFrame(pl.DataFrame):
... _dataframe_class = MyDataFrame
...
If an end user needs to extend both `DataFrame` and `LazyFrame`, it can be done like
so:
>>> class MyLazyFrame(pl.LazyFrame):
... @classmethod
... @property
... def _dataframe_class(cls):
... return MyDataFrame
...
>>> class MyDataFrame(pl.DataFrame):
... _lazyframe_class = MyLazyFrame
...
"""

def __init__(cls, name: str, bases: tuple, clsdict: dict) -> None:
"""Construct new DataFrame class."""
if not bases:
# This is not a subclass of DataFrame and we can simply hard-link to
# LazyFrame instead of dynamically defining a new subclass of LazyFrame.
cls._lazyframe_class = LazyFrame
elif cls._lazyframe_class is LazyFrame:
# This is a subclass of DataFrame which has *not* specified a custom
# LazyFrame subclass by setting `cls._lazyframe_class`. We must therefore
# dynamically create a subclass of LazyFrame with `_dataframe_class` set
# to `cls` in order to preserve types after `.lazy().collect()` roundtrips.
cls._lazyframe_class = type( # type: ignore
f"Lazy{name}",
(LazyFrame,),
{"_dataframe_class": cls},
)
super().__init__(name, bases, clsdict)


class DataFrame(metaclass=DataFrameMetaClass):
"""
A DataFrame is a two-dimensional data structure that represents data as a table
with rows and columns.
Expand Down Expand Up @@ -1610,9 +1680,7 @@ def rename(self: DF, mapping: Dict[str, str]) -> DF:
└───────┴─────┴─────┘
"""
return self._from_pydf(
self.lazy().rename(mapping).collect(no_optimization=True)._df
)
return self.lazy().rename(mapping).collect(no_optimization=True)

def insert_at_idx(self, index: int, series: "pli.Series") -> None:
"""
Expand Down Expand Up @@ -1674,11 +1742,10 @@ def filter(self: DF, predicate: "pli.Expr") -> DF:
└─────┴─────┴─────┘
"""
return self._from_pydf(
return (
self.lazy()
.filter(predicate)
.collect(no_optimization=True, string_cache=False)
._df
)

@property
Expand Down Expand Up @@ -2028,12 +2095,11 @@ def sort(
self.lazy()
.sort(by, reverse)
.collect(no_optimization=True, string_cache=False)
._df
)
if in_place:
self._df = df
self._df = df._df
return self
return self._from_pydf(df)
return df
if in_place:
self._df.sort_in_place(by, reverse)
return None
Expand Down Expand Up @@ -3067,7 +3133,7 @@ def join_asof(
force_parallel
Force the physical plan to evaluate the computation of both DataFrames up to the join in parallel.
"""
return self._from_pydf(
return (
self.lazy()
.join_asof(
df.lazy(),
Expand All @@ -3084,7 +3150,6 @@ def join_asof(
force_parallel=force_parallel,
)
.collect(no_optimization=True)
._df
)

def join(
Expand Down Expand Up @@ -3217,7 +3282,7 @@ def join(
or asof_by_right is not None
or asof_by is not None
):
return self._from_pydf(
return (
self.lazy()
.join(
df.lazy(),
Expand All @@ -3231,7 +3296,6 @@ def join(
asof_by=asof_by,
)
.collect(no_optimization=True)
._df
)
else:
return self._from_pydf(
Expand Down Expand Up @@ -3603,9 +3667,7 @@ def fill_null(self: DF, strategy: Union[str, "pli.Expr", Any]) -> DF:
DataFrame with None values replaced by the filling strategy.
"""
if isinstance(strategy, pli.Expr):
return self._from_pydf(
self.lazy().fill_null(strategy).collect(no_optimization=True)._df
)
return self.lazy().fill_null(strategy).collect(no_optimization=True)
if not isinstance(strategy, str):
return self.fill_null(pli.lit(strategy))
return self._from_pydf(self._df.fill_null(strategy))
Expand All @@ -3628,9 +3690,7 @@ def fill_nan(self: DF, fill_value: Union["pli.Expr", int, float]) -> DF:
-------
DataFrame with NaN replaced with fill_value
"""
return self._from_pydf(
self.lazy().fill_nan(fill_value).collect(no_optimization=True)._df
)
return self.lazy().fill_nan(fill_value).collect(no_optimization=True)

def explode(
self: DF,
Expand Down Expand Up @@ -3702,9 +3762,7 @@ def explode(
└─────────┴─────┘
"""
return self._from_pydf(
self.lazy().explode(columns).collect(no_optimization=True)._df
)
return self.lazy().explode(columns).collect(no_optimization=True)

def pivot(
self: DF,
Expand Down Expand Up @@ -3926,11 +3984,10 @@ def shift_and_fill(
└─────┴─────┴─────┘
"""
return self._from_pydf(
return (
self.lazy()
.shift_and_fill(periods, fill_value)
.collect(no_optimization=True, string_cache=False)
._df
)

def is_duplicated(self) -> "pli.Series":
Expand Down Expand Up @@ -3985,7 +4042,7 @@ def is_unique(self) -> "pli.Series":
"""
return pli.wrap_s(self._df.is_unique())

def lazy(self) -> "pli.LazyFrame":
def lazy(self: DF) -> "pli.LazyFrame[DF]":
"""
Start a lazy query from this point. This returns a `LazyFrame` object.
Expand All @@ -3999,7 +4056,7 @@ def lazy(self) -> "pli.LazyFrame":
Lazy operations are advised because they allow for query optimization and more parallelization.
"""
return pli.wrap_ldf(self._df.lazy())
return self._lazyframe_class._from_pyldf(self._df.lazy())

def select(
self: DF,
Expand Down Expand Up @@ -4042,11 +4099,10 @@ def select(
└─────┘
"""
return self._from_pydf(
return (
self.lazy()
.select(exprs) # type: ignore
.collect(no_optimization=True, string_cache=False)
._df
)

def with_columns(self: DF, exprs: Union["pli.Expr", List["pli.Expr"]]) -> DF:
Expand All @@ -4060,11 +4116,10 @@ def with_columns(self: DF, exprs: Union["pli.Expr", List["pli.Expr"]]) -> DF:
"""
if not isinstance(exprs, list):
exprs = [exprs]
return self._from_pydf(
return (
self.lazy()
.with_columns(exprs)
.collect(no_optimization=True, string_cache=False)
._df
)

def n_chunks(self) -> int:
Expand Down
37 changes: 30 additions & 7 deletions py-polars/polars/internals/lazy_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Any,
Callable,
Dict,
Generic,
List,
Optional,
Sequence,
Expand All @@ -35,6 +36,11 @@
# including sub-classes.
LDF = TypeVar("LDF", bound="LazyFrame")

# We redefine the DF type variable from polars.internals.frame here in order to prevent
# circular import issues. The frame module needs this module to be defined at import
# time due to how the metaclass of DataFrame is defined.
DF = TypeVar("DF", bound="pli.DataFrame")


def wrap_ldf(ldf: "PyLazyFrame") -> "LazyFrame":
return LazyFrame._from_pyldf(ldf)
Expand All @@ -58,7 +64,7 @@ def _prepare_groupby_inputs(
return new_by


class LazyFrame:
class LazyFrame(Generic[DF]):
"""
Representation of a Lazy computation graph/ query.
"""
Expand All @@ -72,6 +78,23 @@ def _from_pyldf(cls: Type[LDF], ldf: "PyLazyFrame") -> LDF:
self._ldf = ldf
return self

@property
def _dataframe_class(self) -> Type[DF]:
"""
Return the associated DataFrame which is the equivalent of this LazyFrame object.
This class is used when a LazyFrame object is casted to a non-lazy representation
by the invocation of `.collect()`, `.fetch()`, and so on. By default we specify
the regular `polars.internals.frame.DataFrame` class here, but any subclass of
DataFrame that wishes to preserve its type when converted to LazyFrame and back
(with `.lazy().collect()` for instance) must overwrite this class variable
before setting DataFrame._lazyframe_class.
This property is dynamically overwritten when DataFrame is sub-classed. See
`polars.internals.frame.DataFrameMetaClass.__init__` for implementation details.
"""
return pli.DataFrame # type: ignore

@classmethod
def scan_csv(
cls: Type[LDF],
Expand Down Expand Up @@ -393,7 +416,7 @@ def collect(
string_cache: bool = False,
no_optimization: bool = False,
slice_pushdown: bool = True,
) -> pli.DataFrame:
) -> DF:
"""
Collect into a DataFrame.
Expand Down Expand Up @@ -439,7 +462,7 @@ def collect(
string_cache,
slice_pushdown,
)
return pli.wrap_df(ldf.collect())
return self._dataframe_class._from_pydf(ldf.collect())

def fetch(
self,
Expand All @@ -451,7 +474,7 @@ def fetch(
string_cache: bool = True,
no_optimization: bool = False,
slice_pushdown: bool = True,
) -> pli.DataFrame:
) -> DF:
"""
Fetch is like a collect operation, but it overwrites the number of rows read by every scan
operation. This is a utility that helps debug a query on a smaller number of rows.
Expand Down Expand Up @@ -497,7 +520,7 @@ def fetch(
string_cache,
slice_pushdown,
)
return pli.wrap_df(ldf.fetch(n_rows))
return self._dataframe_class._from_pydf(ldf.fetch(n_rows))

@property
def columns(self) -> List[str]:
Expand Down Expand Up @@ -1852,7 +1875,7 @@ def melt(

def map(
self: LDF,
f: Callable[[pli.DataFrame], pli.DataFrame],
f: Callable[["pli.DataFrame"], "pli.DataFrame"],
predicate_pushdown: bool = True,
projection_pushdown: bool = True,
no_optimizations: bool = False,
Expand Down Expand Up @@ -2077,7 +2100,7 @@ def tail(self, n: int = 5) -> "LazyFrame":
"""
return wrap_ldf(self.lgb.tail(n))

def apply(self, f: Callable[[pli.DataFrame], pli.DataFrame]) -> "LazyFrame":
def apply(self, f: Callable[["pli.DataFrame"], "pli.DataFrame"]) -> "LazyFrame":
"""
Apply a function over the groups as a new `DataFrame`. It is not recommended that you use
this as materializing the `DataFrame` is quite expensive.
Expand Down
18 changes: 17 additions & 1 deletion py-polars/tests/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from builtins import range
from datetime import datetime
from io import BytesIO
from typing import Any, Iterator
from typing import Any, Iterator, Type
from unittest.mock import patch

import numpy as np
Expand Down Expand Up @@ -1897,3 +1897,19 @@ class SubClassedDataFrame(pl.DataFrame):
# Methods which yield new dataframes should preserve the subclass,
# and here we choose a random method to test with
assert isinstance(df.transpose(), SubClassedDataFrame)

# The type of the dataframe should be preserved when casted to LazyFrame and back
assert isinstance(df.lazy().collect(), SubClassedDataFrame)

# Check if the end user can extend the functionality of both DataFrame and LazyFrame
# and connect these classes together
class MyLazyFrame(pl.LazyFrame):
@property
def _dataframe_class(cls) -> "Type[MyDataFrame]":
return MyDataFrame

class MyDataFrame(pl.DataFrame):
_lazyframe_class = MyLazyFrame

assert isinstance(MyDataFrame().lazy(), MyLazyFrame)
assert isinstance(MyDataFrame().lazy().collect(), MyDataFrame)

0 comments on commit 0038bd2

Please sign in to comment.