diff --git a/examples/polars/lazyframe/README.md b/examples/polars/lazyframe/README.md
new file mode 100644
index 000000000..daad1f010
--- /dev/null
+++ b/examples/polars/lazyframe/README.md
@@ -0,0 +1,37 @@
+# Classic Hamilton Hello World
+
+In this example we show you how to create a simple hello-world dataflow that
+produces a Polars LazyFrame as a result. It performs a series of transforms on the
+input to create columns that appear in the output.
+
+File organization:
+
+* `my_functions.py` houses the logic that we want to compute.
+Note (1) how the functions are named, and (2) what input
+parameters they require. That is how we create a DAG modeling the dataflow we want to happen.
+* `my_script.py` shows how to get Hamilton to create the DAG, specifying that we want a Polars
+LazyFrame back, and exercises it with some inputs.
+
+To run things:
+```bash
+> python my_script.py
+```
+
+# Visualizing Execution
+Here is the graph of execution - it should look the same as in the pandas example:
+
+![polars](polars.png)
+
+# Caveat with Polars
+There is one major caveat with Polars to be aware of: THERE IS NO INDEX IN POLARS LIKE THERE IS WITH PANDAS.
+
+This means that when you tell Hamilton to execute and return a Polars dataframe using the
+[provided results builder](https://github.com/dagworks-inc/hamilton/blob/sf-hamilton-1.14.1/hamilton/plugins/h_polars.py#L8), i.e. `hamilton.plugins.h_polars.PolarsResultsBuilder`, you have to
+ensure the row order matches the order you expect for all the outputs you request. E.g. if you do a filter, a sort,
+a join, or a groupby, it is on you to make sure each output is in the expected row order when Hamilton
+materializes it.
+
+If you have questions, or need help with this example,
+join us on [slack](https://join.slack.com/t/hamilton-opensource/shared_invite/zt-1bjs72asx-wcUTgH7q7QX1igiQ5bbdcg), and we'll try to help!
+
+Otherwise, if you have ideas on how to better make Hamilton work with Polars, please open an issue or start a discussion!
diff --git a/examples/polars/lazyframe/my_functions.py b/examples/polars/lazyframe/my_functions.py
new file mode 100644
index 000000000..cc1d6ef62
--- /dev/null
+++ b/examples/polars/lazyframe/my_functions.py
@@ -0,0 +1,15 @@
+import polars as pl
+
+from hamilton.function_modifiers import load_from, value
+
+
+@load_from.csv(file=value("./sample_data.csv"))
+def raw_data(data: pl.LazyFrame) -> pl.LazyFrame:
+    return data
+
+
+def spend_per_signup(raw_data: pl.LazyFrame) -> pl.LazyFrame:
+    """Computes the spend per signup."""
+    return raw_data.select("spend", "signups").with_columns(
+        [(pl.col("spend") / pl.col("signups")).alias("spend_per_signup")]
+    )
diff --git a/examples/polars/lazyframe/my_script.py b/examples/polars/lazyframe/my_script.py
new file mode 100644
index 000000000..0e0fe4ad9
--- /dev/null
+++ b/examples/polars/lazyframe/my_script.py
@@ -0,0 +1,36 @@
+import logging
+import sys
+
+from hamilton import base, driver
+from hamilton.plugins import h_polars_lazyframe
+
+logging.basicConfig(stream=sys.stdout)
+
+# Create a driver instance. If you want the final node to compute the result, you can instead use
+# h_polars.PolarsDataFrameResult(), in which case you don't need to call collect() at the end.
+# Which one to use depends on whether you want to keep using the LazyFrame in more nodes in
+# another DAG before computing the result.
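+# For example, a sketch (not exercised in this script) of the eager alternative would be:
+#   from hamilton.plugins import h_polars
+#   adapter = base.SimplePythonGraphAdapter(result_builder=h_polars.PolarsDataFrameResult())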
+adapter = base.SimplePythonGraphAdapter(result_builder=h_polars_lazyframe.PolarsLazyFrameResult())
+import my_functions  # where our functions are defined
+
+dr = driver.Driver({}, my_functions, adapter=adapter)
+output_columns = [
+    "spend_per_signup",
+]
+# let's create the LazyFrame!
+df = dr.execute(output_columns)
+# Here we just print the LazyFrame's query plan
+print(df)
+
+# Now we run the query
+df = df.collect()
+
+# And print the table.
+print(df)
+
+# To visualize, do `pip install "sf-hamilton[visualization]"`, then uncomment these:
+# dr.visualize_execution(output_columns, './polars', {"format": "png"})
+# dr.display_all_functions('./my_full_dag.dot')
diff --git a/examples/polars/lazyframe/requirements.txt b/examples/polars/lazyframe/requirements.txt
new file mode 100644
index 000000000..3874d6b99
--- /dev/null
+++ b/examples/polars/lazyframe/requirements.txt
@@ -0,0 +1,2 @@
+polars
+sf-hamilton
diff --git a/examples/polars/lazyframe/sample_data.csv b/examples/polars/lazyframe/sample_data.csv
new file mode 100644
index 000000000..d7fde73b5
--- /dev/null
+++ b/examples/polars/lazyframe/sample_data.csv
@@ -0,0 +1,7 @@
+signups,spend
+1,10
+10,10
+50,20
+100,40
+200,40
+400,50
diff --git a/hamilton/function_modifiers/base.py b/hamilton/function_modifiers/base.py
index 7ae2a34a0..be3038443 100644
--- a/hamilton/function_modifiers/base.py
+++ b/hamilton/function_modifiers/base.py
@@ -26,6 +26,7 @@
     "pandas",
     "plotly",
     "polars",
+    "polars_lazyframe",
     "pyspark_pandas",
     "spark",
     "dask",
diff --git a/hamilton/io/utils.py b/hamilton/io/utils.py
index ffa4546f4..95dd9588d 100644
--- a/hamilton/io/utils.py
+++ b/hamilton/io/utils.py
@@ -62,14 +62,27 @@ def get_dataframe_metadata(df: pd.DataFrame) -> Dict[str, Any]:
     - the column names
     - the data types
     """
-    return {
-        DATAFRAME_METADATA: {
-            "rows": len(df),
-            "columns": len(df.columns),
-            "column_names": list(df.columns),
-            "datatypes": [str(t) for t in list(df.dtypes)],  # for serialization purposes
-        }
-    }
+    metadata = {}
+    try:
+        metadata["rows"] = len(df)
+    except TypeError:
+        metadata["rows"] = None
+
+    try:
+        metadata["columns"] = len(df.columns)
+    except (AttributeError, TypeError):
+        metadata["columns"] = None
+
+    try:
+        metadata["column_names"] = list(df.columns)
+    except (AttributeError, TypeError):
+        metadata["column_names"] = None
+
+    try:
+        metadata["datatypes"] = [str(t) for t in list(df.dtypes)]
+    except (AttributeError, TypeError):
+        metadata["datatypes"] = None
+    return {DATAFRAME_METADATA: metadata}


 def get_file_and_dataframe_metadata(path: str, df: pd.DataFrame) -> Dict[str, Any]:
diff --git a/hamilton/plugins/h_polars.py b/hamilton/plugins/h_polars.py
index b2f847b86..8d910ad75 100644
--- a/hamilton/plugins/h_polars.py
+++ b/hamilton/plugins/h_polars.py
@@ -40,6 +40,8 @@ def build_result(
             (value,) = outputs.values()  # this works because it's length 1.
             if isinstance(value, pl.DataFrame):  # it's a dataframe
                 return value
+            if isinstance(value, pl.LazyFrame):  # it's a lazyframe
+                return value.collect()
             elif not isinstance(value, pl.Series):  # it's a single scalar/object
                 key, value = outputs.popitem()
                 return pl.DataFrame({key: [value]})
diff --git a/hamilton/plugins/h_polars_lazyframe.py b/hamilton/plugins/h_polars_lazyframe.py
new file mode 100644
index 000000000..9d019f2c5
--- /dev/null
+++ b/hamilton/plugins/h_polars_lazyframe.py
@@ -0,0 +1,47 @@
+from typing import Any, Dict, Type, Union
+
+import polars as pl
+
+from hamilton import base
+
+
+class PolarsLazyFrameResult(base.ResultMixin):
+    """A ResultBuilder that produces a polars LazyFrame.
+
+    Use this when you want to create a polars LazyFrame from the outputs. Caveat: you need to ensure that the length
+    of the outputs is the same, otherwise you will get an error; mixed outputs aren't that well handled.
+
+    To use:
+
+    .. code-block:: python
+
+        from hamilton import base, driver
+        from hamilton.plugins import h_polars_lazyframe
+        polars_builder = h_polars_lazyframe.PolarsLazyFrameResult()
+        adapter = base.SimplePythonGraphAdapter(polars_builder)
+        dr = driver.Driver(config, *modules, adapter=adapter)
+        df = dr.execute([...], inputs=...)  # returns a polars LazyFrame
+
+    Note: this is just a first attempt at something for Polars. Think it should handle more? Come chat/open a PR!
+    """
+
+    def build_result(
+        self, **outputs: Dict[str, Union[pl.Series, pl.LazyFrame, Any]]
+    ) -> pl.LazyFrame:
+        """This is the method that Hamilton will call to build the final result. It will pass in the results
+        of the requested outputs that you passed in to the execute() method.
+
+        Note: this function could do smarter things; looking for contributions here!
+
+        :param outputs: The results of the requested outputs.
+        :return: a polars LazyFrame.
+        """
+        if len(outputs) == 1:
+            (value,) = outputs.values()  # this works because it's length 1.
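+            # a lone LazyFrame passes through as-is; anything else falls through to pl.LazyFrame(outputs) below.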
+ if isinstance(value, pl.LazyFrame): # it's a lazyframe + return value + return pl.LazyFrame(outputs) + + def output_type(self) -> Type: + return pl.LazyFrame diff --git a/hamilton/plugins/polars_extensions.py b/hamilton/plugins/polars_extensions.py index 67ce124a3..c903ba8b6 100644 --- a/hamilton/plugins/polars_extensions.py +++ b/hamilton/plugins/polars_extensions.py @@ -166,8 +166,6 @@ def _get_loading_kwargs(self): kwargs["row_count_name"] = self.row_count_name if self.row_count_offset is not None: kwargs["row_count_offset"] = self.row_count_offset - if self.sample_size is not None: - kwargs["sample_size"] = self.sample_size if self.eol_char is not None: kwargs["eol_char"] = self.eol_char if self.raise_if_empty is not None: @@ -176,6 +174,7 @@ def _get_loading_kwargs(self): def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]: df = pl.read_csv(self.file, **self._get_loading_kwargs()) + metadata = utils.get_file_and_dataframe_metadata(self.file, df) return df, metadata @@ -206,7 +205,7 @@ class PolarsCSVWriter(DataSaver): @classmethod def applicable_types(cls) -> Collection[Type]: - return [DATAFRAME_TYPE] + return [DATAFRAME_TYPE, pl.LazyFrame] def _get_saving_kwargs(self): kwargs = {} @@ -236,15 +235,12 @@ def _get_saving_kwargs(self): kwargs["quote_style"] = self.quote_style return kwargs - def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]: + def save_data(self, data: Union[DATAFRAME_TYPE, pl.LazyFrame]) -> Dict[str, Any]: + if isinstance(data, pl.LazyFrame): + data = data.collect() data.write_csv(self.file, **self._get_saving_kwargs()) return utils.get_file_and_dataframe_metadata(self.file, data) - def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]: - df = pl.read_csv(self.file, **self._get_loading_kwargs()) - metadata = utils.get_file_and_dataframe_metadata(self.file, df) - return df, metadata - @classmethod def name(cls) -> str: return "csv" @@ -330,7 +326,7 @@ class PolarsParquetWriter(DataSaver): @classmethod def applicable_types(cls) -> Collection[Type]: - return [DATAFRAME_TYPE] + return [DATAFRAME_TYPE, pl.LazyFrame] def _get_saving_kwargs(self): kwargs = {} @@ -348,8 +344,12 @@ def _get_saving_kwargs(self): kwargs["pyarrow_options"] = self.pyarrow_options return kwargs - def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]: + def save_data(self, data: Union[DATAFRAME_TYPE, pl.LazyFrame]) -> Dict[str, Any]: + if isinstance(data, pl.LazyFrame): + data = data.collect() + data.write_parquet(self.file, **self._get_saving_kwargs()) + return utils.get_file_and_dataframe_metadata(self.file, data) @classmethod @@ -422,7 +422,7 @@ class PolarsFeatherWriter(DataSaver): @classmethod def applicable_types(cls) -> Collection[Type]: - return [DATAFRAME_TYPE] + return [DATAFRAME_TYPE, pl.LazyFrame] def _get_saving_kwargs(self): kwargs = {} @@ -430,7 +430,9 @@ def _get_saving_kwargs(self): kwargs["compression"] = self.compression return kwargs - def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]: + def save_data(self, data: Union[DATAFRAME_TYPE, pl.LazyFrame]) -> Dict[str, Any]: + if isinstance(data, pl.LazyFrame): + data = data.collect() data.write_ipc(self.file, **self._get_saving_kwargs()) return utils.get_file_and_dataframe_metadata(self.file, data) @@ -484,7 +486,7 @@ class PolarsAvroWriter(DataSaver): @classmethod def applicable_types(cls) -> Collection[Type]: - return [DATAFRAME_TYPE] + return [DATAFRAME_TYPE, pl.LazyFrame] def _get_saving_kwargs(self): kwargs = {} @@ -492,7 +494,10 @@ def _get_saving_kwargs(self): 
             kwargs["compression"] = self.compression
         return kwargs

-    def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]:
+    def save_data(self, data: Union[DATAFRAME_TYPE, pl.LazyFrame]) -> Dict[str, Any]:
+        if isinstance(data, pl.LazyFrame):
+            data = data.collect()
+
         data.write_avro(self.file, **self._get_saving_kwargs())
         return utils.get_file_and_dataframe_metadata(self.file, data)

@@ -547,7 +552,7 @@ class PolarsJSONWriter(DataSaver):

     @classmethod
     def applicable_types(cls) -> Collection[Type]:
-        return [DATAFRAME_TYPE]
+        return [DATAFRAME_TYPE, pl.LazyFrame]

     def _get_saving_kwargs(self):
         kwargs = {}
@@ -557,7 +562,10 @@
         kwargs["row_oriented"] = self.row_oriented
         return kwargs

-    def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]:
+    def save_data(self, data: Union[DATAFRAME_TYPE, pl.LazyFrame]) -> Dict[str, Any]:
+        if isinstance(data, pl.LazyFrame):
+            data = data.collect()
+
         data.write_json(self.file, **self._get_saving_kwargs())
         return utils.get_file_and_dataframe_metadata(self.file, data)

@@ -665,7 +673,7 @@ class PolarsSpreadsheetWriter(DataSaver):

     @classmethod
     def applicable_types(cls) -> Collection[Type]:
-        return [DATAFRAME_TYPE]
+        return [DATAFRAME_TYPE, pl.LazyFrame]

     def _get_saving_kwargs(self):
         kwargs = {}
@@ -713,7 +721,10 @@
         kwargs["freeze_panes"] = self.freeze_panes
         return kwargs

-    def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]:
+    def save_data(self, data: Union[DATAFRAME_TYPE, pl.LazyFrame]) -> Dict[str, Any]:
+        if isinstance(data, pl.LazyFrame):
+            data = data.collect()
+
         data.write_excel(self.workbook, self.worksheet, **self._get_saving_kwargs())
         return utils.get_file_and_dataframe_metadata(self.workbook, data)

@@ -782,7 +793,7 @@ class PolarsDatabaseWriter(DataSaver):

     @classmethod
     def applicable_types(cls) -> Collection[Type]:
-        return [DATAFRAME_TYPE]
+        return [DATAFRAME_TYPE, pl.LazyFrame]

     def _get_saving_kwargs(self):
         kwargs = {}
@@ -792,7 +803,10 @@
         kwargs["engine"] = self.engine
         return kwargs

-    def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]:
+    def save_data(self, data: Union[DATAFRAME_TYPE, pl.LazyFrame]) -> Dict[str, Any]:
+        if isinstance(data, pl.LazyFrame):
+            data = data.collect()
+
         data.write_database(
             table_name=self.table_name,
             connection=self.connection,
diff --git a/hamilton/plugins/polars_lazyframe_extensions.py b/hamilton/plugins/polars_lazyframe_extensions.py
new file mode 100644
index 000000000..fdabb8460
--- /dev/null
+++ b/hamilton/plugins/polars_lazyframe_extensions.py
@@ -0,0 +1,274 @@
+import dataclasses
+from io import BytesIO
+from pathlib import Path
+from typing import (
+    Any,
+    BinaryIO,
+    Collection,
+    Dict,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    TextIO,
+    Tuple,
+    Type,
+    Union,
+)
+
+try:
+    import polars as pl
+    from polars import PolarsDataType
+except ImportError:
+    raise NotImplementedError("Polars is not installed.")
+
+
+# for polars < 0.16.0 we need to determine whether type_aliases exists.
+has_alias = False
+if hasattr(pl, "type_aliases"):
+    has_alias = True
+
+# for polars >= 0.18.0 we need to check whether CsvEncoding is still importable.
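+# (if it isn't, we fall back to a plain Type so the annotations below still resolve.)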
+if has_alias and hasattr(pl.type_aliases, "CsvEncoding"):
+    from polars.type_aliases import CsvEncoding
+else:
+    CsvEncoding = Type
+
+
+from hamilton import registry
+from hamilton.io import utils
+from hamilton.io.data_adapters import DataLoader
+
+DATAFRAME_TYPE = pl.LazyFrame
+COLUMN_TYPE = None
+COLUMN_FRIENDLY_DF_TYPE = False
+
+
+def register_types():
+    """Function to register the types for this extension."""
+    registry.register_types("polars_lazyframe", DATAFRAME_TYPE, COLUMN_TYPE)
+
+
+register_types()
+
+
+@dataclasses.dataclass
+class PolarsScanCSVReader(DataLoader):
+    """Class specifically to handle lazily loading CSV files with Polars.
+    Should map to https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.scan_csv.html
+    """
+
+    file: Union[str, TextIO, BytesIO, Path, BinaryIO, bytes]
+    # kwargs:
+    has_header: bool = True
+    columns: Union[Sequence[int], Sequence[str]] = None
+    new_columns: Sequence[str] = None
+    separator: str = ","
+    comment_char: str = None
+    quote_char: str = '"'
+    skip_rows: int = 0
+    dtypes: Union[Mapping[str, PolarsDataType], Sequence[PolarsDataType]] = None
+    null_values: Union[str, Sequence[str], Dict[str, str]] = None
+    missing_utf8_is_empty_string: bool = False
+    ignore_errors: bool = False
+    try_parse_dates: bool = False
+    n_threads: int = None
+    infer_schema_length: int = 100
+    batch_size: int = 8192
+    n_rows: int = None
+    encoding: Union[CsvEncoding, str] = "utf8"
+    low_memory: bool = False
+    rechunk: bool = True
+    use_pyarrow: bool = False
+    storage_options: Dict[str, Any] = None
+    skip_rows_after_header: int = 0
+    row_count_name: str = None
+    row_count_offset: int = 0
+    eol_char: str = "\n"
+    raise_if_empty: bool = True
+
+    def _get_loading_kwargs(self):
+        kwargs = {}
+        if self.has_header is not None:
+            kwargs["has_header"] = self.has_header
+        if self.columns is not None:
+            kwargs["columns"] = self.columns
+        if self.new_columns is not None:
+            kwargs["new_columns"] = self.new_columns
+        if self.separator is not None:
+            kwargs["separator"] = self.separator
+        if self.comment_char is not None:
+            kwargs["comment_char"] = self.comment_char
+        if self.quote_char is not None:
+            kwargs["quote_char"] = self.quote_char
+        if self.skip_rows is not None:
+            kwargs["skip_rows"] = self.skip_rows
+        if self.dtypes is not None:
+            kwargs["dtypes"] = self.dtypes
+        if self.null_values is not None:
+            kwargs["null_values"] = self.null_values
+        if self.missing_utf8_is_empty_string is not None:
+            kwargs["missing_utf8_is_empty_string"] = self.missing_utf8_is_empty_string
+        if self.ignore_errors is not None:
+            kwargs["ignore_errors"] = self.ignore_errors
+        if self.try_parse_dates is not None:
+            kwargs["try_parse_dates"] = self.try_parse_dates
+        if self.n_threads is not None:
+            kwargs["n_threads"] = self.n_threads
+        if self.infer_schema_length is not None:
+            kwargs["infer_schema_length"] = self.infer_schema_length
+        if self.n_rows is not None:
+            kwargs["n_rows"] = self.n_rows
+        if self.encoding is not None:
+            kwargs["encoding"] = self.encoding
+        if self.low_memory is not None:
+            kwargs["low_memory"] = self.low_memory
+        if self.rechunk is not None:
+            kwargs["rechunk"] = self.rechunk
+        if self.storage_options is not None:
+            kwargs["storage_options"] = self.storage_options
+        if self.skip_rows_after_header is not None:
+            kwargs["skip_rows_after_header"] = self.skip_rows_after_header
+        if self.row_count_name is not None:
+            kwargs["row_count_name"] = self.row_count_name
+        if self.row_count_offset is not None:
+            kwargs["row_count_offset"] = self.row_count_offset
+        if self.eol_char is not None:
+            kwargs["eol_char"] = self.eol_char
+        if self.raise_if_empty is not None:
+            kwargs["raise_if_empty"] = self.raise_if_empty
+        return kwargs
+
+    @classmethod
+    def applicable_types(cls) -> Collection[Type]:
+        return [DATAFRAME_TYPE]
+
+    def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]:
+        df = pl.scan_csv(self.file, **self._get_loading_kwargs())
+
+        metadata = utils.get_file_and_dataframe_metadata(self.file, df)
+        return df, metadata
+
+    @classmethod
+    def name(cls) -> str:
+        return "csv"
+
+
+@dataclasses.dataclass
+class PolarsScanParquetReader(DataLoader):
+    """Class specifically to handle lazily loading parquet files with Polars.
+    Should map to https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.scan_parquet.html
+    """
+
+    file: Union[str, TextIO, BytesIO, Path, BinaryIO, bytes]
+    # kwargs:
+    columns: Union[List[int], List[str]] = None
+    n_rows: int = None
+    use_pyarrow: bool = False
+    memory_map: bool = True
+    storage_options: Dict[str, Any] = None
+    parallel: Any = "auto"
+    row_count_name: str = None
+    row_count_offset: int = 0
+    low_memory: bool = False
+    use_statistics: bool = True
+    rechunk: bool = True
+
+    @classmethod
+    def applicable_types(cls) -> Collection[Type]:
+        return [DATAFRAME_TYPE]
+
+    def _get_loading_kwargs(self):
+        kwargs = {}
+        if self.columns is not None:
+            kwargs["columns"] = self.columns
+        if self.n_rows is not None:
+            kwargs["n_rows"] = self.n_rows
+        if self.storage_options is not None:
+            kwargs["storage_options"] = self.storage_options
+        if self.parallel is not None:
+            kwargs["parallel"] = self.parallel
+        if self.row_count_name is not None:
+            kwargs["row_count_name"] = self.row_count_name
+        if self.row_count_offset is not None:
+            kwargs["row_count_offset"] = self.row_count_offset
+        if self.low_memory is not None:
+            kwargs["low_memory"] = self.low_memory
+        if self.use_statistics is not None:
+            kwargs["use_statistics"] = self.use_statistics
+        if self.rechunk is not None:
+            kwargs["rechunk"] = self.rechunk
+        return kwargs
+
+    def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]:
+        df = pl.scan_parquet(self.file, **self._get_loading_kwargs())
+        metadata = utils.get_file_and_dataframe_metadata(self.file, df)
+        return df, metadata
+
+    @classmethod
+    def name(cls) -> str:
+        return "parquet"
+
+
+@dataclasses.dataclass
+class PolarsScanFeatherReader(DataLoader):
+    """
+    Class specifically to handle lazily loading Feather/Arrow IPC files with Polars.
+    Should map to https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.scan_ipc.html
+    """
+
+    source: Union[str, BinaryIO, BytesIO, Path, bytes]
+    # kwargs:
+    columns: Optional[Union[List[str], List[int]]] = None
+    n_rows: Optional[int] = None
+    use_pyarrow: bool = False
+    memory_map: bool = True
+    storage_options: Optional[Dict[str, Any]] = None
+    row_count_name: Optional[str] = None
+    row_count_offset: int = 0
+    rechunk: bool = True
+
+    @classmethod
+    def applicable_types(cls) -> Collection[Type]:
+        return [DATAFRAME_TYPE]
+
+    def _get_loading_kwargs(self):
+        kwargs = {}
+        if self.columns is not None:
+            kwargs["columns"] = self.columns
+        if self.n_rows is not None:
+            kwargs["n_rows"] = self.n_rows
+        if self.memory_map is not None:
+            kwargs["memory_map"] = self.memory_map
+        if self.storage_options is not None:
+            kwargs["storage_options"] = self.storage_options
+        if self.row_count_name is not None:
+            kwargs["row_count_name"] = self.row_count_name
+        if self.row_count_offset is not None:
+            kwargs["row_count_offset"] = self.row_count_offset
+        if self.rechunk is not None:
+            kwargs["rechunk"] = self.rechunk
+        return kwargs
+
+    def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]:
+        df = pl.scan_ipc(self.source, **self._get_loading_kwargs())
+        metadata = utils.get_file_metadata(self.source)
+        return df, metadata
+
+    @classmethod
+    def name(cls) -> str:
+        return "feather"
+
+
+def register_data_loaders():
+    """Function to register the data loaders for this extension."""
+    for loader in [
+        PolarsScanCSVReader,
+        PolarsScanParquetReader,
+        PolarsScanFeatherReader,
+    ]:
+        registry.register_adapter(loader)
+
+
+register_data_loaders()
diff --git a/tests/plugins/test_polars_extensions.py b/tests/plugins/test_polars_extensions.py
index bace4d130..e4db44d13 100644
--- a/tests/plugins/test_polars_extensions.py
+++ b/tests/plugins/test_polars_extensions.py
@@ -45,7 +45,7 @@ def test_polars_csv(df: pl.DataFrame, tmp_path: pathlib.Path) -> None:
     kwargs2 = reader._get_loading_kwargs()
     df2, metadata = reader.load_data(pl.DataFrame)

-    assert PolarsCSVWriter.applicable_types() == [pl.DataFrame]
+    assert PolarsCSVWriter.applicable_types() == [pl.DataFrame, pl.LazyFrame]
     assert PolarsCSVReader.applicable_types() == [pl.DataFrame]
     assert kwargs1["separator"] == ","
     assert kwargs2["has_header"] is True
@@ -63,7 +63,7 @@ def test_polars_parquet(df: pl.DataFrame, tmp_path: pathlib.Path) -> None:
     kwargs2 = reader._get_loading_kwargs()
     df2, metadata = reader.load_data(pl.DataFrame)

-    assert PolarsParquetWriter.applicable_types() == [pl.DataFrame]
+    assert PolarsParquetWriter.applicable_types() == [pl.DataFrame, pl.LazyFrame]
     assert PolarsParquetReader.applicable_types() == [pl.DataFrame]
     assert kwargs1["compression"] == "zstd"
     assert kwargs2["n_rows"] == 2
@@ -85,7 +85,7 @@ def test_polars_feather(tmp_path: pathlib.Path) -> None:
     assert "n_rows" not in read_kwargs
     assert df.shape == (4, 3)

-    assert PolarsFeatherWriter.applicable_types() == [pl.DataFrame]
+    assert PolarsFeatherWriter.applicable_types() == [pl.DataFrame, pl.LazyFrame]
     assert "compression" in write_kwargs
     assert file_path.exists()
     assert metadata["file_metadata"]["path"] == str(file_path)
@@ -103,7 +103,7 @@ def test_polars_json(df: pl.DataFrame, tmp_path: pathlib.Path) -> None:
     kwargs2 = reader._get_loading_kwargs()
     df2, metadata = reader.load_data(pl.DataFrame)

-    assert PolarsJSONWriter.applicable_types() == [pl.DataFrame]
+    assert PolarsJSONWriter.applicable_types() == [pl.DataFrame, pl.LazyFrame]
     assert
PolarsJSONReader.applicable_types() == [pl.DataFrame] assert kwargs1["pretty"] assert df2.shape == (2, 2) @@ -122,7 +122,7 @@ def test_polars_avro(df: pl.DataFrame, tmp_path: pathlib.Path) -> None: kwargs2 = reader._get_loading_kwargs() df2, metadata = reader.load_data(pl.DataFrame) - assert PolarsAvroWriter.applicable_types() == [pl.DataFrame] + assert PolarsAvroWriter.applicable_types() == [pl.DataFrame, pl.LazyFrame] assert PolarsAvroReader.applicable_types() == [pl.DataFrame] assert kwargs1["compression"] == "uncompressed" assert kwargs2["n_rows"] == 2 @@ -145,7 +145,7 @@ def test_polars_database(df: pl.DataFrame, tmp_path: pathlib.Path) -> None: kwargs2 = reader._get_loading_kwargs() df2, metadata = reader.load_data(pl.DataFrame) - assert PolarsDatabaseWriter.applicable_types() == [pl.DataFrame] + assert PolarsDatabaseWriter.applicable_types() == [pl.DataFrame, pl.LazyFrame] assert PolarsDatabaseReader.applicable_types() == [pl.DataFrame] assert kwargs1["if_table_exists"] == "replace" assert "batch_size" not in kwargs2 @@ -163,7 +163,7 @@ def test_polars_spreadsheet(df: pl.DataFrame, tmp_path: pathlib.Path) -> None: read_kwargs = reader._get_loading_kwargs() df2, _ = reader.load_data(pl.DataFrame) - assert PolarsSpreadsheetWriter.applicable_types() == [pl.DataFrame] + assert PolarsSpreadsheetWriter.applicable_types() == [pl.DataFrame, pl.LazyFrame] assert PolarsSpreadsheetReader.applicable_types() == [pl.DataFrame] assert file_path.exists() assert metadata["file_metadata"]["path"] == str(file_path) diff --git a/tests/plugins/test_polars_lazyframe_extensions.py b/tests/plugins/test_polars_lazyframe_extensions.py new file mode 100644 index 000000000..c775ad228 --- /dev/null +++ b/tests/plugins/test_polars_lazyframe_extensions.py @@ -0,0 +1,177 @@ +import pathlib +import sys + +import polars as pl +import pytest +from polars.testing import assert_frame_equal + +from hamilton.plugins.polars_extensions import ( + PolarsAvroReader, + PolarsAvroWriter, + PolarsCSVWriter, + PolarsDatabaseReader, + PolarsDatabaseWriter, + PolarsFeatherWriter, + PolarsJSONReader, + PolarsJSONWriter, + PolarsParquetWriter, + PolarsSpreadsheetReader, + PolarsSpreadsheetWriter, +) +from hamilton.plugins.polars_lazyframe_extensions import ( + PolarsScanCSVReader, + PolarsScanFeatherReader, + PolarsScanParquetReader, +) + + +@pytest.fixture +def df(): + yield pl.LazyFrame({"a": [1, 2], "b": [3, 4]}) + + +def test_lazy_polars_lazyframe_csv(df: pl.LazyFrame, tmp_path: pathlib.Path) -> None: + file = tmp_path / "test.csv" + + writer = PolarsCSVWriter(file=file) + kwargs1 = writer._get_saving_kwargs() + writer.save_data(df) + + reader = PolarsScanCSVReader(file=file) + kwargs2 = reader._get_loading_kwargs() + df2, metadata = reader.load_data(pl.LazyFrame) + + assert PolarsCSVWriter.applicable_types() == [pl.DataFrame, pl.LazyFrame] + assert PolarsScanCSVReader.applicable_types() == [pl.LazyFrame] + assert kwargs1["separator"] == "," + assert kwargs2["has_header"] is True + assert_frame_equal(df.collect(), df2.collect()) + + +def test_lazy_polars_parquet(df: pl.LazyFrame, tmp_path: pathlib.Path) -> None: + file = tmp_path / "test.parquet" + + writer = PolarsParquetWriter(file=file) + kwargs1 = writer._get_saving_kwargs() + writer.save_data(df) + + reader = PolarsScanParquetReader(file=file, n_rows=2) + kwargs2 = reader._get_loading_kwargs() + df2, metadata = reader.load_data(pl.LazyFrame) + + assert PolarsParquetWriter.applicable_types() == [pl.DataFrame, pl.LazyFrame] + assert PolarsScanParquetReader.applicable_types() 
== [pl.LazyFrame] + assert kwargs1["compression"] == "zstd" + assert kwargs2["n_rows"] == 2 + assert_frame_equal(df.collect(), df2.collect()) + + +def test_lazy_polars_feather(tmp_path: pathlib.Path) -> None: + test_data_file_path = "tests/resources/data/test_load_from_data.feather" + reader = PolarsScanFeatherReader(source=test_data_file_path) + read_kwargs = reader._get_loading_kwargs() + df, _ = reader.load_data(pl.LazyFrame) + + file_path = tmp_path / "test.dta" + writer = PolarsFeatherWriter(file=file_path) + write_kwargs = writer._get_saving_kwargs() + metadata = writer.save_data(df.collect()) + + assert PolarsScanFeatherReader.applicable_types() == [pl.LazyFrame] + assert "n_rows" not in read_kwargs + assert df.collect().shape == (4, 3) + + assert PolarsFeatherWriter.applicable_types() == [pl.DataFrame, pl.LazyFrame] + assert "compression" in write_kwargs + assert file_path.exists() + assert metadata["file_metadata"]["path"] == str(file_path) + assert metadata["dataframe_metadata"]["column_names"] == [ + "animal", + "points", + "environment", + ] + assert metadata["dataframe_metadata"]["datatypes"] == ["String", "Int64", "String"] + + +def test_lazy_polars_avro(df: pl.LazyFrame, tmp_path: pathlib.Path) -> None: + file = tmp_path / "test.avro" + + writer = PolarsAvroWriter(file=file) + kwargs1 = writer._get_saving_kwargs() + writer.save_data(df) + + reader = PolarsAvroReader(file=file, n_rows=2) + kwargs2 = reader._get_loading_kwargs() + df2, metadata = reader.load_data(pl.DataFrame) + + assert PolarsAvroWriter.applicable_types() == [pl.DataFrame, pl.LazyFrame] + assert PolarsAvroReader.applicable_types() == [pl.DataFrame] + assert kwargs1["compression"] == "uncompressed" + assert kwargs2["n_rows"] == 2 + assert_frame_equal(df.collect(), df2) + + +def test_polars_json(df: pl.LazyFrame, tmp_path: pathlib.Path) -> None: + file = tmp_path / "test.json" + writer = PolarsJSONWriter(file=file, pretty=True) + kwargs1 = writer._get_saving_kwargs() + writer.save_data(df) + + reader = PolarsJSONReader(source=file) + kwargs2 = reader._get_loading_kwargs() + df2, metadata = reader.load_data(pl.DataFrame) + + assert PolarsJSONWriter.applicable_types() == [pl.DataFrame, pl.LazyFrame] + assert PolarsJSONReader.applicable_types() == [pl.DataFrame] + assert kwargs1["pretty"] + assert df2.shape == (2, 2) + assert "schema" not in kwargs2 + assert_frame_equal(df.collect(), df2) + + +@pytest.mark.skipif( + sys.version_info.major == 3 and sys.version_info.minor == 12, + reason="weird connectorx error on 3.12", +) +def test_polars_database(df: pl.LazyFrame, tmp_path: pathlib.Path) -> None: + conn = f"sqlite:///{tmp_path}/test.db" + table_name = "test_table" + + writer = PolarsDatabaseWriter(table_name=table_name, connection=conn, if_table_exists="replace") + kwargs1 = writer._get_saving_kwargs() + writer.save_data(df) + + reader = PolarsDatabaseReader(query=f"SELECT * FROM {table_name}", connection=conn) + kwargs2 = reader._get_loading_kwargs() + df2, metadata = reader.load_data(pl.DataFrame) + + assert PolarsDatabaseWriter.applicable_types() == [pl.DataFrame, pl.LazyFrame] + assert PolarsDatabaseReader.applicable_types() == [pl.DataFrame] + assert kwargs1["if_table_exists"] == "replace" + assert "batch_size" not in kwargs2 + assert df2.shape == (2, 2) + assert_frame_equal(df.collect(), df2) + + +def test_polars_spreadsheet(df: pl.LazyFrame, tmp_path: pathlib.Path) -> None: + file_path = tmp_path / "test.xlsx" + writer = PolarsSpreadsheetWriter(workbook=file_path, worksheet="test_load_from_data_sheet") + 
write_kwargs = writer._get_saving_kwargs() + metadata = writer.save_data(df) + + reader = PolarsSpreadsheetReader(source=file_path, sheet_name="test_load_from_data_sheet") + read_kwargs = reader._get_loading_kwargs() + df2, _ = reader.load_data(pl.DataFrame) + + assert PolarsSpreadsheetWriter.applicable_types() == [pl.DataFrame, pl.LazyFrame] + assert PolarsSpreadsheetReader.applicable_types() == [pl.DataFrame] + assert file_path.exists() + assert metadata["file_metadata"]["path"] == str(file_path) + assert df.collect().shape == (2, 2) + assert metadata["dataframe_metadata"]["column_names"] == ["a", "b"] + assert metadata["dataframe_metadata"]["datatypes"] == ["Int64", "Int64"] + assert_frame_equal(df.collect(), df2) + assert "include_header" in write_kwargs + assert write_kwargs["include_header"] is True + assert "raise_if_empty" in read_kwargs + assert read_kwargs["raise_if_empty"] is True