From 39ce9e03e7abcd5b371baa16b5453b69cf6228a1 Mon Sep 17 00:00:00 2001
From: Tom Barber
Date: Thu, 28 Mar 2024 18:42:07 +0000
Subject: [PATCH] Initial Polars LazyFrame Support (#775)

This PR adds support for Polars LazyFrames in Hamilton.

* Extended applicable types for Polars writers

The applicable types for PolarsCSVWriter, PolarsParquetWriter, and PolarsFeatherWriter have been extended to include pl.LazyFrame in addition to the existing pl.DataFrame, so these writer classes can handle both eager and lazy frames from the polars library.

* Updated PolarsLazyFrameResult and data writers

A new PolarsLazyFrameResult is now used in place of the PolarsDataFrameResult, and the logging statement in register_types() has been removed. DataSaver classes have been updated to handle both DATAFRAME_TYPE and pl.LazyFrame, with a check added to collect the data if it is a LazyFrame before saving. Tests have been updated and expanded to cover these changes, including checks for applicable types and correct handling of LazyFrames.

* Extended support for LazyFrame in Polars extensions

The applicable_types method in the PolarsSpreadsheetWriter class, and the corresponding test assertions, have been updated to include pl.LazyFrame alongside the existing DATAFRAME_TYPE. This extends our Polars extensions to handle LazyFrames as well as DataFrames.

* Added Polars LazyFrame example

This update introduces a new example demonstrating the use of a Polars LazyFrame. The changes include:

- Two new Python scripts: one defining functions for loading data and calculating spend per signup, and another that executes those functions.
- A README explaining how to run the example and visualize execution, and detailing some caveats with Polars.
- A requirements.txt file specifying the necessary dependencies.
- Sample CSV data for testing purposes.

* Updated data loading method in tests

The test methods for PolarsScanParquetReader and PolarsScanFeatherReader now load data with pl.LazyFrame instead of pl.DataFrame, which aligns with the applicable types for these readers.

Notes: I also had to update get_dataframe_metadata in utils.py so that it works with LazyFrames, which don't have a row count. Each lookup is now wrapped individually, so that readers/writers supported in the future will get back whatever metadata can be gathered.
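For reviewers, here is a minimal sketch of the new writer behavior (the file path is illustrative; the calls mirror the tests added below):

```python
import polars as pl

from hamilton.plugins.polars_extensions import PolarsCSVWriter

lf = pl.LazyFrame({"a": [1, 2], "b": [3, 4]})

# save_data() now accepts a LazyFrame and collects it internally before
# calling write_csv, so eager and lazy frames go through the same saver.
writer = PolarsCSVWriter(file="./example.csv")  # illustrative path
metadata = writer.save_data(lf)
print(metadata["dataframe_metadata"]["column_names"])  # ['a', 'b']
```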
---------

Co-authored-by: Tom Barber
---
 examples/polars/lazyframe/README.md           |  37 +++
 examples/polars/lazyframe/my_functions.py     |  15 +
 examples/polars/lazyframe/my_script.py        |  33 +++
 examples/polars/lazyframe/requirements.txt    |   2 +
 examples/polars/lazyframe/sample_data.csv     |   7 +
 hamilton/function_modifiers/base.py           |   1 +
 hamilton/io/utils.py                          |  29 +-
 hamilton/plugins/h_polars.py                  |   2 +
 hamilton/plugins/h_polars_lazyframe.py        |  46 +++
 hamilton/plugins/polars_extensions.py         |  56 ++--
 .../plugins/polars_lazyframe_extensions.py    | 273 ++++++++++++++++++
 tests/plugins/test_polars_extensions.py       |  14 +-
 .../test_polars_lazyframe_extensions.py       | 177 ++++++++++++
 13 files changed, 656 insertions(+), 36 deletions(-)
 create mode 100644 examples/polars/lazyframe/README.md
 create mode 100644 examples/polars/lazyframe/my_functions.py
 create mode 100644 examples/polars/lazyframe/my_script.py
 create mode 100644 examples/polars/lazyframe/requirements.txt
 create mode 100644 examples/polars/lazyframe/sample_data.csv
 create mode 100644 hamilton/plugins/h_polars_lazyframe.py
 create mode 100644 hamilton/plugins/polars_lazyframe_extensions.py
 create mode 100644 tests/plugins/test_polars_lazyframe_extensions.py

diff --git a/examples/polars/lazyframe/README.md b/examples/polars/lazyframe/README.md
new file mode 100644
index 000000000..daad1f010
--- /dev/null
+++ b/examples/polars/lazyframe/README.md
@@ -0,0 +1,37 @@
+# Polars LazyFrame Hello World
+
+In this example we show you how to create a simple hello-world dataflow that
+produces a polars LazyFrame as its result. It performs a series of transforms on the
+input to create columns that appear in the output.
+
+File organization:
+
+* `my_functions.py` houses the logic that we want to compute.
+Note how the functions are named and what input
+parameters they require. That is how we create a DAG modeling the dataflow we want to happen.
+* `my_script.py` houses how to get Hamilton to create the DAG, specifying that we want a polars
+LazyFrame as the result, and exercises it with some inputs.
+
+To run things:
+```bash
+> python my_script.py
+```
+
+# Visualizing Execution
+Here is the graph of execution - which should look the same as the pandas example:
+
+![polars](polars.png)
+
+# Caveat with Polars
+There is one major caveat with Polars to be aware of: THERE IS NO INDEX IN POLARS LIKE THERE IS WITH PANDAS.
+
+This means that when you tell Hamilton to execute and return a polars dataframe using the
+[provided results builder](https://github.com/dagworks-inc/hamilton/blob/sf-hamilton-1.14.1/hamilton/plugins/h_polars.py#L8), i.e. `hamilton.plugins.h_polars.PolarsResultsBuilder`, you have to
+ensure the row order matches the order you expect for all the outputs you request. E.g. if you do a filter, a sort,
+a join, or a groupby, you will have to ensure that when you ask Hamilton to materialize an output, it is in the
+order you expect.
+
+If you have questions, or need help with this example,
+join us on [slack](https://join.slack.com/t/hamilton-opensource/shared_invite/zt-1bjs72asx-wcUTgH7q7QX1igiQ5bbdcg), and we'll try to help!
+
+Otherwise if you have ideas on how to better make Hamilton work with Polars, please open an issue or start a discussion!
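To make the README's row-order caveat concrete, here is a minimal sketch (not part of the committed files; column names are illustrative):

```python
import polars as pl

df = pl.DataFrame({"signups": [100, 1, 50], "spend": [40, 10, 20]})

# Two derived outputs, one sorted and one not. With no index, there is no way
# to realign these rows when they are assembled into one final frame.
sorted_spend = df.sort("signups").get_column("spend")
spend_per_signup = df.get_column("spend") / df.get_column("signups")

# Safer pattern: apply one explicit ordering, then derive every output from it.
aligned = df.sort("signups").with_columns(
    (pl.col("spend") / pl.col("signups")).alias("spend_per_signup")
)
print(aligned)
```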
diff --git a/examples/polars/lazyframe/my_functions.py b/examples/polars/lazyframe/my_functions.py
new file mode 100644
index 000000000..cc1d6ef62
--- /dev/null
+++ b/examples/polars/lazyframe/my_functions.py
@@ -0,0 +1,15 @@
+import polars as pl
+
+from hamilton.function_modifiers import load_from, value
+
+
+@load_from.csv(file=value("./sample_data.csv"))
+def raw_data(data: pl.LazyFrame) -> pl.LazyFrame:
+    return data
+
+
+def spend_per_signup(raw_data: pl.LazyFrame) -> pl.LazyFrame:
+    """Computes cost per signup in relation to spend."""
+    return raw_data.select("spend", "signups").with_columns(
+        [(pl.col("spend") / pl.col("signups")).alias("spend_per_signup")]
+    )
diff --git a/examples/polars/lazyframe/my_script.py b/examples/polars/lazyframe/my_script.py
new file mode 100644
index 000000000..0e0fe4ad9
--- /dev/null
+++ b/examples/polars/lazyframe/my_script.py
@@ -0,0 +1,33 @@
+import logging
+import sys
+
+from hamilton import base, driver
+from hamilton.plugins import h_polars_lazyframe
+
+logging.basicConfig(stream=sys.stdout)
+
+# Create a driver instance. If you want to run the compute in the final node you can instead use
+# h_polars.PolarsDataFrameResult(), in which case you don't need to call collect() at the end.
+# Which one you use depends on whether you want to use the LazyFrame in more nodes in another DAG
+# before computing the result.
+adapter = base.SimplePythonGraphAdapter(result_builder=h_polars_lazyframe.PolarsLazyFrameResult())
+import my_functions  # where our functions are defined
+
+dr = driver.Driver({}, my_functions, adapter=adapter)
+output_columns = [
+    "spend_per_signup",
+]
+# let's create the lazyframe!
+df = dr.execute(output_columns)
+# Here we just print the LazyFrame plan
+print(df)
+
+# Now we run the query
+df = df.collect()
+
+# And print the table.
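+# (df is now an eager pl.DataFrame: collect() executed the lazy query plan
+# that the earlier print only displayed.)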
+print(df)
+
+# To visualize do `pip install "sf-hamilton[visualization]"` if you want these to work
+# dr.visualize_execution(output_columns, './polars', {"format": "png"})
+# dr.display_all_functions('./my_full_dag.dot')
diff --git a/examples/polars/lazyframe/requirements.txt b/examples/polars/lazyframe/requirements.txt
new file mode 100644
index 000000000..3874d6b99
--- /dev/null
+++ b/examples/polars/lazyframe/requirements.txt
@@ -0,0 +1,2 @@
+polars
+sf-hamilton
diff --git a/examples/polars/lazyframe/sample_data.csv b/examples/polars/lazyframe/sample_data.csv
new file mode 100644
index 000000000..d7fde73b5
--- /dev/null
+++ b/examples/polars/lazyframe/sample_data.csv
@@ -0,0 +1,7 @@
+signups,spend
+1,10
+10,10
+50,20
+100,40
+200,40
+400,50
diff --git a/hamilton/function_modifiers/base.py b/hamilton/function_modifiers/base.py
index 7ae2a34a0..be3038443 100644
--- a/hamilton/function_modifiers/base.py
+++ b/hamilton/function_modifiers/base.py
@@ -26,6 +26,7 @@
     "pandas",
     "plotly",
     "polars",
+    "polars_lazyframe",
     "pyspark_pandas",
     "spark",
     "dask",
diff --git a/hamilton/io/utils.py b/hamilton/io/utils.py
index ffa4546f4..95dd9588d 100644
--- a/hamilton/io/utils.py
+++ b/hamilton/io/utils.py
@@ -62,14 +62,27 @@ def get_dataframe_metadata(df: pd.DataFrame) -> Dict[str, Any]:
     - the column names
     - the data types
     """
-    return {
-        DATAFRAME_METADATA: {
-            "rows": len(df),
-            "columns": len(df.columns),
-            "column_names": list(df.columns),
-            "datatypes": [str(t) for t in list(df.dtypes)],  # for serialization purposes
-        }
-    }
+    metadata = {}
+    try:
+        metadata["rows"] = len(df)
+    except TypeError:
+        metadata["rows"] = None
+
+    try:
+        metadata["columns"] = len(df.columns)
+    except (AttributeError, TypeError):
+        metadata["columns"] = None
+
+    try:
+        metadata["column_names"] = list(df.columns)
+    except (AttributeError, TypeError):
+        metadata["column_names"] = None
+
+    try:
+        metadata["datatypes"] = [str(t) for t in list(df.dtypes)]
+    except (AttributeError, TypeError):
+        metadata["datatypes"] = None
+    return {DATAFRAME_METADATA: metadata}


 def get_file_and_dataframe_metadata(path: str, df: pd.DataFrame) -> Dict[str, Any]:
diff --git a/hamilton/plugins/h_polars.py b/hamilton/plugins/h_polars.py
index b2f847b86..8d910ad75 100644
--- a/hamilton/plugins/h_polars.py
+++ b/hamilton/plugins/h_polars.py
@@ -40,6 +40,8 @@ def build_result(
         (value,) = outputs.values()  # this works because it's length 1.
         if isinstance(value, pl.DataFrame):  # it's a dataframe
             return value
+        if isinstance(value, pl.LazyFrame):  # it's a lazyframe
+            return value.collect()
         elif not isinstance(value, pl.Series):  # it's a single scalar/object
             key, value = outputs.popitem()
             return pl.DataFrame({key: [value]})
diff --git a/hamilton/plugins/h_polars_lazyframe.py b/hamilton/plugins/h_polars_lazyframe.py
new file mode 100644
index 000000000..9d019f2c5
--- /dev/null
+++ b/hamilton/plugins/h_polars_lazyframe.py
@@ -0,0 +1,46 @@
+from typing import Any, Dict, Type, Union
+
+import polars as pl
+
+from hamilton import base
+
+
+class PolarsLazyFrameResult(base.ResultMixin):
+    """A ResultBuilder that produces a polars LazyFrame.
+
+    Use this when you want to create a polars LazyFrame from the outputs. Caveat: you need to ensure that the length
+    of the outputs is the same, otherwise you will get an error; mixed outputs aren't that well handled.
+
+    To use:
+
+    .. code-block:: python
+
+        from hamilton import base, driver
+        from hamilton.plugins import h_polars_lazyframe
+
+        polars_builder = h_polars_lazyframe.PolarsLazyFrameResult()
+        adapter = base.SimplePythonGraphAdapter(polars_builder)
+        dr = driver.Driver(config, *modules, adapter=adapter)
+        df = dr.execute([...], inputs=...)  # returns polars LazyFrame
+
+    Note: this is just a first attempt at something for Polars. Think it should handle more? Come chat/open a PR!
+    """
+
+    def build_result(
+        self, **outputs: Dict[str, Union[pl.Series, pl.LazyFrame, Any]]
+    ) -> pl.LazyFrame:
+        """This is the method that Hamilton will call to build the final result. It will pass in the results
+        of the requested outputs that you passed in to the execute() method.
+
+        Note: this function could do smarter things; looking for contributions here!
+
+        :param outputs: The results of the requested outputs.
+        :return: a polars LazyFrame.
+        """
+        if len(outputs) == 1:
+            (value,) = outputs.values()  # this works because it's length 1.
+            if isinstance(value, pl.LazyFrame):  # it's a lazyframe
+                return value
+        return pl.LazyFrame(outputs)
+
+    def output_type(self) -> Type:
+        return pl.LazyFrame
diff --git a/hamilton/plugins/polars_extensions.py b/hamilton/plugins/polars_extensions.py
index 67ce124a3..c903ba8b6 100644
--- a/hamilton/plugins/polars_extensions.py
+++ b/hamilton/plugins/polars_extensions.py
@@ -166,8 +166,6 @@ def _get_loading_kwargs(self):
             kwargs["row_count_name"] = self.row_count_name
         if self.row_count_offset is not None:
             kwargs["row_count_offset"] = self.row_count_offset
-        if self.sample_size is not None:
-            kwargs["sample_size"] = self.sample_size
         if self.eol_char is not None:
             kwargs["eol_char"] = self.eol_char
         if self.raise_if_empty is not None:
@@ -176,6 +174,7 @@ def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]:
         df = pl.read_csv(self.file, **self._get_loading_kwargs())
+
         metadata = utils.get_file_and_dataframe_metadata(self.file, df)
         return df, metadata
@@ -206,7 +205,7 @@ class PolarsCSVWriter(DataSaver):

     @classmethod
     def applicable_types(cls) -> Collection[Type]:
-        return [DATAFRAME_TYPE]
+        return [DATAFRAME_TYPE, pl.LazyFrame]
@@ -236,15 +235,12 @@ def _get_saving_kwargs(self):
             kwargs["quote_style"] = self.quote_style
         return kwargs

-    def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]:
+    def save_data(self, data: Union[DATAFRAME_TYPE, pl.LazyFrame]) -> Dict[str, Any]:
+        if isinstance(data, pl.LazyFrame):
+            data = data.collect()
         data.write_csv(self.file, **self._get_saving_kwargs())
         return utils.get_file_and_dataframe_metadata(self.file, data)

-    def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]:
-        df = pl.read_csv(self.file, **self._get_loading_kwargs())
-        metadata = utils.get_file_and_dataframe_metadata(self.file, df)
-        return df, metadata
-
     @classmethod
     def name(cls) -> str:
         return "csv"
@@ -330,7 +326,7 @@ class PolarsParquetWriter(DataSaver):

     @classmethod
     def applicable_types(cls) -> Collection[Type]:
-        return [DATAFRAME_TYPE]
+        return [DATAFRAME_TYPE, pl.LazyFrame]
@@ -348,8 +344,12 @@ def _get_saving_kwargs(self):
             kwargs["pyarrow_options"] = self.pyarrow_options
         return kwargs

-    def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]:
+    def save_data(self, data: Union[DATAFRAME_TYPE, pl.LazyFrame]) -> Dict[str, Any]:
+        if isinstance(data, pl.LazyFrame):
+            data = data.collect()
+
+        data.write_parquet(self.file,
**self._get_saving_kwargs()) + return utils.get_file_and_dataframe_metadata(self.file, data) @classmethod @@ -422,7 +422,7 @@ class PolarsFeatherWriter(DataSaver): @classmethod def applicable_types(cls) -> Collection[Type]: - return [DATAFRAME_TYPE] + return [DATAFRAME_TYPE, pl.LazyFrame] def _get_saving_kwargs(self): kwargs = {} @@ -430,7 +430,9 @@ def _get_saving_kwargs(self): kwargs["compression"] = self.compression return kwargs - def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]: + def save_data(self, data: Union[DATAFRAME_TYPE, pl.LazyFrame]) -> Dict[str, Any]: + if isinstance(data, pl.LazyFrame): + data = data.collect() data.write_ipc(self.file, **self._get_saving_kwargs()) return utils.get_file_and_dataframe_metadata(self.file, data) @@ -484,7 +486,7 @@ class PolarsAvroWriter(DataSaver): @classmethod def applicable_types(cls) -> Collection[Type]: - return [DATAFRAME_TYPE] + return [DATAFRAME_TYPE, pl.LazyFrame] def _get_saving_kwargs(self): kwargs = {} @@ -492,7 +494,10 @@ def _get_saving_kwargs(self): kwargs["compression"] = self.compression return kwargs - def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]: + def save_data(self, data: Union[DATAFRAME_TYPE, pl.LazyFrame]) -> Dict[str, Any]: + if isinstance(data, pl.LazyFrame): + data = data.collect() + data.write_avro(self.file, **self._get_saving_kwargs()) return utils.get_file_and_dataframe_metadata(self.file, data) @@ -547,7 +552,7 @@ class PolarsJSONWriter(DataSaver): @classmethod def applicable_types(cls) -> Collection[Type]: - return [DATAFRAME_TYPE] + return [DATAFRAME_TYPE, pl.LazyFrame] def _get_saving_kwargs(self): kwargs = {} @@ -557,7 +562,10 @@ def _get_saving_kwargs(self): kwargs["row_oriented"] = self.row_oriented return kwargs - def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]: + def save_data(self, data: Union[DATAFRAME_TYPE, pl.LazyFrame]) -> Dict[str, Any]: + if isinstance(data, pl.LazyFrame): + data = data.collect() + data.write_json(self.file, **self._get_saving_kwargs()) return utils.get_file_and_dataframe_metadata(self.file, data) @@ -665,7 +673,7 @@ class PolarsSpreadsheetWriter(DataSaver): @classmethod def applicable_types(cls) -> Collection[Type]: - return [DATAFRAME_TYPE] + return [DATAFRAME_TYPE, pl.LazyFrame] def _get_saving_kwargs(self): kwargs = {} @@ -713,7 +721,10 @@ def _get_saving_kwargs(self): kwargs["freeze_panes"] = self.freeze_panes return kwargs - def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]: + def save_data(self, data: Union[DATAFRAME_TYPE, pl.LazyFrame]) -> Dict[str, Any]: + if isinstance(data, pl.LazyFrame): + data = data.collect() + data.write_excel(self.workbook, self.worksheet, **self._get_saving_kwargs()) return utils.get_file_and_dataframe_metadata(self.workbook, data) @@ -782,7 +793,7 @@ class PolarsDatabaseWriter(DataSaver): @classmethod def applicable_types(cls) -> Collection[Type]: - return [DATAFRAME_TYPE] + return [DATAFRAME_TYPE, pl.LazyFrame] def _get_saving_kwargs(self): kwargs = {} @@ -792,7 +803,10 @@ def _get_saving_kwargs(self): kwargs["engine"] = self.engine return kwargs - def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]: + def save_data(self, data: Union[DATAFRAME_TYPE, pl.LazyFrame]) -> Dict[str, Any]: + if isinstance(data, pl.LazyFrame): + data = data.collect() + data.write_database( table_name=self.table_name, connection=self.connection, diff --git a/hamilton/plugins/polars_lazyframe_extensions.py b/hamilton/plugins/polars_lazyframe_extensions.py new file mode 100644 index 000000000..fdabb8460 --- /dev/null 
+++ b/hamilton/plugins/polars_lazyframe_extensions.py @@ -0,0 +1,273 @@ +import dataclasses +from io import BytesIO +from pathlib import Path +from typing import ( + Any, + BinaryIO, + Collection, + Dict, + List, + Mapping, + Optional, + Sequence, + TextIO, + Tuple, + Type, + Union, +) + +try: + import polars as pl + from polars import PolarsDataType +except ImportError: + raise NotImplementedError("Polars is not installed.") + + +# for polars <0.16.0 we need to determine whether type_aliases exist. +has_alias = False +if hasattr(pl, "type_aliases"): + has_alias = True + +# for polars 0.18.0 we need to check what to do. +if has_alias and hasattr(pl.type_aliases, "CsvEncoding"): + from polars.type_aliases import CsvEncoding +else: + CsvEncoding = Type + + +from hamilton import registry +from hamilton.io import utils +from hamilton.io.data_adapters import DataLoader + +DATAFRAME_TYPE = pl.LazyFrame +COLUMN_TYPE = None +COLUMN_FRIENDLY_DF_TYPE = False + + +def register_types(): + """Function to register the types for this extension.""" + registry.register_types("polars_lazyframe", DATAFRAME_TYPE, COLUMN_TYPE) + + +register_types() + + +@dataclasses.dataclass +class PolarsScanCSVReader(DataLoader): + """Class specifically to handle loading CSV files with Polars. + Should map to https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.read_csv.html + """ + + file: Union[str, TextIO, BytesIO, Path, BinaryIO, bytes] + # kwargs: + has_header: bool = True + columns: Union[Sequence[int], Sequence[str]] = None + new_columns: Sequence[str] = None + separator: str = "," + comment_char: str = None + quote_char: str = '"' + skip_rows: int = 0 + dtypes: Union[Mapping[str, PolarsDataType], Sequence[PolarsDataType]] = None + null_values: Union[str, Sequence[str], Dict[str, str]] = None + missing_utf8_is_empty_string: bool = False + ignore_errors: bool = False + try_parse_dates: bool = False + n_threads: int = None + infer_schema_length: int = 100 + batch_size: int = 8192 + n_rows: int = None + encoding: Union[CsvEncoding, str] = "utf8" + low_memory: bool = False + rechunk: bool = True + use_pyarrow: bool = False + storage_options: Dict[str, Any] = None + skip_rows_after_header: int = 0 + row_count_name: str = None + row_count_offset: int = 0 + eol_char: str = "\n" + raise_if_empty: bool = True + + def _get_loading_kwargs(self): + kwargs = {} + if self.has_header is not None: + kwargs["has_header"] = self.has_header + if self.columns is not None: + kwargs["columns"] = self.columns + if self.new_columns is not None: + kwargs["new_columns"] = self.new_columns + if self.separator is not None: + kwargs["separator"] = self.separator + if self.comment_char is not None: + kwargs["comment_char"] = self.comment_char + if self.quote_char is not None: + kwargs["quote_char"] = self.quote_char + if self.skip_rows is not None: + kwargs["skip_rows"] = self.skip_rows + if self.dtypes is not None: + kwargs["dtypes"] = self.dtypes + if self.null_values is not None: + kwargs["null_values"] = self.null_values + if self.missing_utf8_is_empty_string is not None: + kwargs["missing_utf8_is_empty_string"] = self.missing_utf8_is_empty_string + if self.ignore_errors is not None: + kwargs["ignore_errors"] = self.ignore_errors + if self.try_parse_dates is not None: + kwargs["try_parse_dates"] = self.try_parse_dates + if self.n_threads is not None: + kwargs["n_threads"] = self.n_threads + if self.infer_schema_length is not None: + kwargs["infer_schema_length"] = self.infer_schema_length + if self.n_rows is not None: + 
kwargs["n_rows"] = self.n_rows + if self.encoding is not None: + kwargs["encoding"] = self.encoding + if self.low_memory is not None: + kwargs["low_memory"] = self.low_memory + if self.rechunk is not None: + kwargs["rechunk"] = self.rechunk + if self.storage_options is not None: + kwargs["storage_options"] = self.storage_options + if self.skip_rows_after_header is not None: + kwargs["skip_rows_after_header"] = self.skip_rows_after_header + if self.row_count_name is not None: + kwargs["row_count_name"] = self.row_count_name + if self.row_count_offset is not None: + kwargs["row_count_offset"] = self.row_count_offset + if self.eol_char is not None: + kwargs["eol_char"] = self.eol_char + if self.raise_if_empty is not None: + kwargs["raise_if_empty"] = self.raise_if_empty + return kwargs + + @classmethod + def applicable_types(cls) -> Collection[Type]: + return [DATAFRAME_TYPE] + + def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]: + df = pl.scan_csv(self.file, **self._get_loading_kwargs()) + + metadata = utils.get_file_and_dataframe_metadata(self.file, df) + return df, metadata + + @classmethod + def name(cls) -> str: + return "csv" + + +@dataclasses.dataclass +class PolarsScanParquetReader(DataLoader): + """Class specifically to handle loading parquet files with polars + Should map to https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.read_parquet.html + """ + + file: Union[str, TextIO, BytesIO, Path, BinaryIO, bytes] + # kwargs: + columns: Union[List[int], List[str]] = None + n_rows: int = None + use_pyarrow: bool = False + memory_map: bool = True + storage_options: Dict[str, Any] = None + parallel: Any = "auto" + row_count_name: str = None + row_count_offset: int = 0 + low_memory: bool = False + use_statistics: bool = True + rechunk: bool = True + + @classmethod + def applicable_types(cls) -> Collection[Type]: + return [DATAFRAME_TYPE] + + def _get_loading_kwargs(self): + kwargs = {} + if self.columns is not None: + kwargs["columns"] = self.columns + if self.n_rows is not None: + kwargs["n_rows"] = self.n_rows + if self.storage_options is not None: + kwargs["storage_options"] = self.storage_options + if self.parallel is not None: + kwargs["parallel"] = self.parallel + if self.row_count_name is not None: + kwargs["row_count_name"] = self.row_count_name + if self.row_count_offset is not None: + kwargs["row_count_offset"] = self.row_count_offset + if self.low_memory is not None: + kwargs["low_memory"] = self.low_memory + if self.use_statistics is not None: + kwargs["use_statistics"] = self.use_statistics + if self.rechunk is not None: + kwargs["rechunk"] = self.rechunk + return kwargs + + def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]: + df = pl.scan_parquet(self.file, **self._get_loading_kwargs()) + metadata = utils.get_file_and_dataframe_metadata(self.file, df) + return df, metadata + + @classmethod + def name(cls) -> str: + return "parquet" + + +@dataclasses.dataclass +class PolarsScanFeatherReader(DataLoader): + """ + Class specifically to handle loading Feather/Arrow IPC files with Polars. 
+ Should map to https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.read_ipc.html + """ + + source: Union[str, BinaryIO, BytesIO, Path, bytes] + # kwargs: + columns: Optional[Union[List[str], List[int]]] = None + n_rows: Optional[int] = None + use_pyarrow: bool = False + memory_map: bool = True + storage_options: Optional[Dict[str, Any]] = None + row_count_name: Optional[str] = None + row_count_offset: int = 0 + rechunk: bool = True + + @classmethod + def applicable_types(cls) -> Collection[Type]: + return [DATAFRAME_TYPE] + + def _get_loading_kwargs(self): + kwargs = {} + if self.columns is not None: + kwargs["columns"] = self.columns + if self.n_rows is not None: + kwargs["n_rows"] = self.n_rows + if self.memory_map is not None: + kwargs["memory_map"] = self.memory_map + if self.storage_options is not None: + kwargs["storage_options"] = self.storage_options + if self.row_count_name is not None: + kwargs["row_count_name"] = self.row_count_name + if self.row_count_offset is not None: + kwargs["row_count_offset"] = self.row_count_offset + if self.rechunk is not None: + kwargs["rechunk"] = self.rechunk + return kwargs + + def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]: + df = pl.scan_ipc(self.source, **self._get_loading_kwargs()) + metadata = utils.get_file_metadata(self.source) + return df, metadata + + @classmethod + def name(cls) -> str: + return "feather" + + +def register_data_loaders(): + """Function to register the data loaders for this extension.""" + for loader in [ + PolarsScanCSVReader, + PolarsScanParquetReader, + PolarsScanFeatherReader, + ]: + registry.register_adapter(loader) + + +register_data_loaders() diff --git a/tests/plugins/test_polars_extensions.py b/tests/plugins/test_polars_extensions.py index bace4d130..e4db44d13 100644 --- a/tests/plugins/test_polars_extensions.py +++ b/tests/plugins/test_polars_extensions.py @@ -45,7 +45,7 @@ def test_polars_csv(df: pl.DataFrame, tmp_path: pathlib.Path) -> None: kwargs2 = reader._get_loading_kwargs() df2, metadata = reader.load_data(pl.DataFrame) - assert PolarsCSVWriter.applicable_types() == [pl.DataFrame] + assert PolarsCSVWriter.applicable_types() == [pl.DataFrame, pl.LazyFrame] assert PolarsCSVReader.applicable_types() == [pl.DataFrame] assert kwargs1["separator"] == "," assert kwargs2["has_header"] is True @@ -63,7 +63,7 @@ def test_polars_parquet(df: pl.DataFrame, tmp_path: pathlib.Path) -> None: kwargs2 = reader._get_loading_kwargs() df2, metadata = reader.load_data(pl.DataFrame) - assert PolarsParquetWriter.applicable_types() == [pl.DataFrame] + assert PolarsParquetWriter.applicable_types() == [pl.DataFrame, pl.LazyFrame] assert PolarsParquetReader.applicable_types() == [pl.DataFrame] assert kwargs1["compression"] == "zstd" assert kwargs2["n_rows"] == 2 @@ -85,7 +85,7 @@ def test_polars_feather(tmp_path: pathlib.Path) -> None: assert "n_rows" not in read_kwargs assert df.shape == (4, 3) - assert PolarsFeatherWriter.applicable_types() == [pl.DataFrame] + assert PolarsFeatherWriter.applicable_types() == [pl.DataFrame, pl.LazyFrame] assert "compression" in write_kwargs assert file_path.exists() assert metadata["file_metadata"]["path"] == str(file_path) @@ -103,7 +103,7 @@ def test_polars_json(df: pl.DataFrame, tmp_path: pathlib.Path) -> None: kwargs2 = reader._get_loading_kwargs() df2, metadata = reader.load_data(pl.DataFrame) - assert PolarsJSONWriter.applicable_types() == [pl.DataFrame] + assert PolarsJSONWriter.applicable_types() == [pl.DataFrame, pl.LazyFrame] assert 
PolarsJSONReader.applicable_types() == [pl.DataFrame] assert kwargs1["pretty"] assert df2.shape == (2, 2) @@ -122,7 +122,7 @@ def test_polars_avro(df: pl.DataFrame, tmp_path: pathlib.Path) -> None: kwargs2 = reader._get_loading_kwargs() df2, metadata = reader.load_data(pl.DataFrame) - assert PolarsAvroWriter.applicable_types() == [pl.DataFrame] + assert PolarsAvroWriter.applicable_types() == [pl.DataFrame, pl.LazyFrame] assert PolarsAvroReader.applicable_types() == [pl.DataFrame] assert kwargs1["compression"] == "uncompressed" assert kwargs2["n_rows"] == 2 @@ -145,7 +145,7 @@ def test_polars_database(df: pl.DataFrame, tmp_path: pathlib.Path) -> None: kwargs2 = reader._get_loading_kwargs() df2, metadata = reader.load_data(pl.DataFrame) - assert PolarsDatabaseWriter.applicable_types() == [pl.DataFrame] + assert PolarsDatabaseWriter.applicable_types() == [pl.DataFrame, pl.LazyFrame] assert PolarsDatabaseReader.applicable_types() == [pl.DataFrame] assert kwargs1["if_table_exists"] == "replace" assert "batch_size" not in kwargs2 @@ -163,7 +163,7 @@ def test_polars_spreadsheet(df: pl.DataFrame, tmp_path: pathlib.Path) -> None: read_kwargs = reader._get_loading_kwargs() df2, _ = reader.load_data(pl.DataFrame) - assert PolarsSpreadsheetWriter.applicable_types() == [pl.DataFrame] + assert PolarsSpreadsheetWriter.applicable_types() == [pl.DataFrame, pl.LazyFrame] assert PolarsSpreadsheetReader.applicable_types() == [pl.DataFrame] assert file_path.exists() assert metadata["file_metadata"]["path"] == str(file_path) diff --git a/tests/plugins/test_polars_lazyframe_extensions.py b/tests/plugins/test_polars_lazyframe_extensions.py new file mode 100644 index 000000000..c775ad228 --- /dev/null +++ b/tests/plugins/test_polars_lazyframe_extensions.py @@ -0,0 +1,177 @@ +import pathlib +import sys + +import polars as pl +import pytest +from polars.testing import assert_frame_equal + +from hamilton.plugins.polars_extensions import ( + PolarsAvroReader, + PolarsAvroWriter, + PolarsCSVWriter, + PolarsDatabaseReader, + PolarsDatabaseWriter, + PolarsFeatherWriter, + PolarsJSONReader, + PolarsJSONWriter, + PolarsParquetWriter, + PolarsSpreadsheetReader, + PolarsSpreadsheetWriter, +) +from hamilton.plugins.polars_lazyframe_extensions import ( + PolarsScanCSVReader, + PolarsScanFeatherReader, + PolarsScanParquetReader, +) + + +@pytest.fixture +def df(): + yield pl.LazyFrame({"a": [1, 2], "b": [3, 4]}) + + +def test_lazy_polars_lazyframe_csv(df: pl.LazyFrame, tmp_path: pathlib.Path) -> None: + file = tmp_path / "test.csv" + + writer = PolarsCSVWriter(file=file) + kwargs1 = writer._get_saving_kwargs() + writer.save_data(df) + + reader = PolarsScanCSVReader(file=file) + kwargs2 = reader._get_loading_kwargs() + df2, metadata = reader.load_data(pl.LazyFrame) + + assert PolarsCSVWriter.applicable_types() == [pl.DataFrame, pl.LazyFrame] + assert PolarsScanCSVReader.applicable_types() == [pl.LazyFrame] + assert kwargs1["separator"] == "," + assert kwargs2["has_header"] is True + assert_frame_equal(df.collect(), df2.collect()) + + +def test_lazy_polars_parquet(df: pl.LazyFrame, tmp_path: pathlib.Path) -> None: + file = tmp_path / "test.parquet" + + writer = PolarsParquetWriter(file=file) + kwargs1 = writer._get_saving_kwargs() + writer.save_data(df) + + reader = PolarsScanParquetReader(file=file, n_rows=2) + kwargs2 = reader._get_loading_kwargs() + df2, metadata = reader.load_data(pl.LazyFrame) + + assert PolarsParquetWriter.applicable_types() == [pl.DataFrame, pl.LazyFrame] + assert PolarsScanParquetReader.applicable_types() 
== [pl.LazyFrame] + assert kwargs1["compression"] == "zstd" + assert kwargs2["n_rows"] == 2 + assert_frame_equal(df.collect(), df2.collect()) + + +def test_lazy_polars_feather(tmp_path: pathlib.Path) -> None: + test_data_file_path = "tests/resources/data/test_load_from_data.feather" + reader = PolarsScanFeatherReader(source=test_data_file_path) + read_kwargs = reader._get_loading_kwargs() + df, _ = reader.load_data(pl.LazyFrame) + + file_path = tmp_path / "test.dta" + writer = PolarsFeatherWriter(file=file_path) + write_kwargs = writer._get_saving_kwargs() + metadata = writer.save_data(df.collect()) + + assert PolarsScanFeatherReader.applicable_types() == [pl.LazyFrame] + assert "n_rows" not in read_kwargs + assert df.collect().shape == (4, 3) + + assert PolarsFeatherWriter.applicable_types() == [pl.DataFrame, pl.LazyFrame] + assert "compression" in write_kwargs + assert file_path.exists() + assert metadata["file_metadata"]["path"] == str(file_path) + assert metadata["dataframe_metadata"]["column_names"] == [ + "animal", + "points", + "environment", + ] + assert metadata["dataframe_metadata"]["datatypes"] == ["String", "Int64", "String"] + + +def test_lazy_polars_avro(df: pl.LazyFrame, tmp_path: pathlib.Path) -> None: + file = tmp_path / "test.avro" + + writer = PolarsAvroWriter(file=file) + kwargs1 = writer._get_saving_kwargs() + writer.save_data(df) + + reader = PolarsAvroReader(file=file, n_rows=2) + kwargs2 = reader._get_loading_kwargs() + df2, metadata = reader.load_data(pl.DataFrame) + + assert PolarsAvroWriter.applicable_types() == [pl.DataFrame, pl.LazyFrame] + assert PolarsAvroReader.applicable_types() == [pl.DataFrame] + assert kwargs1["compression"] == "uncompressed" + assert kwargs2["n_rows"] == 2 + assert_frame_equal(df.collect(), df2) + + +def test_polars_json(df: pl.LazyFrame, tmp_path: pathlib.Path) -> None: + file = tmp_path / "test.json" + writer = PolarsJSONWriter(file=file, pretty=True) + kwargs1 = writer._get_saving_kwargs() + writer.save_data(df) + + reader = PolarsJSONReader(source=file) + kwargs2 = reader._get_loading_kwargs() + df2, metadata = reader.load_data(pl.DataFrame) + + assert PolarsJSONWriter.applicable_types() == [pl.DataFrame, pl.LazyFrame] + assert PolarsJSONReader.applicable_types() == [pl.DataFrame] + assert kwargs1["pretty"] + assert df2.shape == (2, 2) + assert "schema" not in kwargs2 + assert_frame_equal(df.collect(), df2) + + +@pytest.mark.skipif( + sys.version_info.major == 3 and sys.version_info.minor == 12, + reason="weird connectorx error on 3.12", +) +def test_polars_database(df: pl.LazyFrame, tmp_path: pathlib.Path) -> None: + conn = f"sqlite:///{tmp_path}/test.db" + table_name = "test_table" + + writer = PolarsDatabaseWriter(table_name=table_name, connection=conn, if_table_exists="replace") + kwargs1 = writer._get_saving_kwargs() + writer.save_data(df) + + reader = PolarsDatabaseReader(query=f"SELECT * FROM {table_name}", connection=conn) + kwargs2 = reader._get_loading_kwargs() + df2, metadata = reader.load_data(pl.DataFrame) + + assert PolarsDatabaseWriter.applicable_types() == [pl.DataFrame, pl.LazyFrame] + assert PolarsDatabaseReader.applicable_types() == [pl.DataFrame] + assert kwargs1["if_table_exists"] == "replace" + assert "batch_size" not in kwargs2 + assert df2.shape == (2, 2) + assert_frame_equal(df.collect(), df2) + + +def test_polars_spreadsheet(df: pl.LazyFrame, tmp_path: pathlib.Path) -> None: + file_path = tmp_path / "test.xlsx" + writer = PolarsSpreadsheetWriter(workbook=file_path, worksheet="test_load_from_data_sheet") + 
write_kwargs = writer._get_saving_kwargs() + metadata = writer.save_data(df) + + reader = PolarsSpreadsheetReader(source=file_path, sheet_name="test_load_from_data_sheet") + read_kwargs = reader._get_loading_kwargs() + df2, _ = reader.load_data(pl.DataFrame) + + assert PolarsSpreadsheetWriter.applicable_types() == [pl.DataFrame, pl.LazyFrame] + assert PolarsSpreadsheetReader.applicable_types() == [pl.DataFrame] + assert file_path.exists() + assert metadata["file_metadata"]["path"] == str(file_path) + assert df.collect().shape == (2, 2) + assert metadata["dataframe_metadata"]["column_names"] == ["a", "b"] + assert metadata["dataframe_metadata"]["datatypes"] == ["Int64", "Int64"] + assert_frame_equal(df.collect(), df2) + assert "include_header" in write_kwargs + assert write_kwargs["include_header"] is True + assert "raise_if_empty" in read_kwargs + assert read_kwargs["raise_if_empty"] is True
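As a closing illustration, here is a minimal round-trip sketch using the new scan readers and updated writers (the file path is illustrative; the calls mirror the tests above):

```python
import polars as pl

from hamilton.plugins.polars_extensions import PolarsParquetWriter
from hamilton.plugins.polars_lazyframe_extensions import PolarsScanParquetReader

lf = pl.LazyFrame({"a": [1, 2], "b": [3, 4]})

# The writer accepts a LazyFrame and collects it internally before writing.
PolarsParquetWriter(file="./example.parquet").save_data(lf)  # illustrative path

# The scan reader stays lazy (pl.scan_parquet under the hood) and returns a
# LazyFrame plus whatever metadata can be gathered without triggering compute.
lf2, metadata = PolarsScanParquetReader(file="./example.parquet").load_data(pl.LazyFrame)
assert isinstance(lf2, pl.LazyFrame)
print(lf2.collect())
```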